Commit d054cfe

Avoid init_nccl for every step.
1 parent 158d567 commit d054cfe

File tree

3 files changed, +48 -32 lines changed

paddle/fluid/operators/nccl/nccl_gpu_common.cc

Lines changed: 40 additions & 1 deletion
@@ -16,5 +16,44 @@ limitations under the License. */
 #include "paddle/fluid/platform/gpu_info.h"

 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+namespace {
+// TODO(panyx0718): Where to destroy them.
+std::unique_ptr<std::vector<ncclComm_t>> global_comms;
+std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
+bool inited = false;
+size_t last_num_gpus = -1;
+}
+
+int Communicator::GetCommId(int device_id) const {
+  return comm_id_map->at(device_id);
+}
+
+void Communicator::InitAll(const std::vector<int>& gpus) {
+  if (inited && last_num_gpus == gpus.size()) {
+    return;
+  }
+  last_num_gpus = gpus.size();
+  if (global_comms) {
+    for (size_t i = 0; i < global_comms->size(); ++i) {
+      // FIXME(dzh) : PADDLE_ENFORCE return void
+      dynload::ncclCommDestroy((*global_comms)[i]);
+    }
+  }
+  global_comms.reset(new std::vector<ncclComm_t>());
+  comm_id_map.reset(new std::unordered_map<int, int>());
+  global_comms->resize(gpus.size());
+  for (size_t i = 0; i < gpus.size(); ++i) {
+    (*comm_id_map)[gpus[i]] = i;
+  }
+  PADDLE_ENFORCE(
+      dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
+  inited = true;
+}
+
+const std::vector<ncclComm_t>& Communicator::comms() const {
+  return *global_comms;
+}
+
+}  // namespace platform
 }  // namespace paddle
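The heart of the change is the early return at the top of InitAll: communicator state now lives in file-local globals, so a second call with the same number of GPUs reuses the existing communicators instead of running ncclCommInitAll again on every step. Below is a minimal standalone sketch of the same caching pattern; FakeComm and FakeCommInitAll are hypothetical stand-ins for the NCCL types, used only to keep the sketch self-contained.

#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical stand-ins for ncclComm_t / ncclCommInitAll (not Paddle or
// NCCL symbols); they exist only to make this sketch runnable.
using FakeComm = int;
void FakeCommInitAll(std::vector<FakeComm>* comms, const std::vector<int>& gpus) {
  for (size_t i = 0; i < gpus.size(); ++i) (*comms)[i] = 1000 + gpus[i];
}

namespace {
std::unique_ptr<std::vector<FakeComm>> global_comms;
bool inited = false;
size_t last_num_gpus = static_cast<size_t>(-1);
}  // namespace

void InitAll(const std::vector<int>& gpus) {
  // Fast path added by this commit: if communicators were already built for
  // the same number of GPUs, reuse them instead of re-initializing.
  if (inited && last_num_gpus == gpus.size()) {
    std::printf("InitAll: cached, skipping\n");
    return;
  }
  last_num_gpus = gpus.size();
  global_comms.reset(new std::vector<FakeComm>(gpus.size()));
  FakeCommInitAll(global_comms.get(), gpus);
  inited = true;
  std::printf("InitAll: initialized %zu comms\n", gpus.size());
}

int main() {
  InitAll({0, 1});  // first step: pays the initialization cost
  InitAll({0, 1});  // later steps: return immediately
}

Note that the guard compares only gpus.size(), exactly as in the diff above: a later call with a different set of GPU ids of the same count would reuse the old communicators.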

paddle/fluid/operators/nccl/nccl_gpu_common.h

Lines changed: 3 additions & 26 deletions
@@ -29,39 +29,16 @@ limitations under the License. */

 namespace paddle {
 namespace platform {
-
 constexpr int kInvalidGPUId = -1;

 struct Communicator {
-  std::vector<ncclComm_t> comms_;
-  std::unordered_map<int, int> comm_id_map_;
-  bool inited_;
-
   Communicator() {}

-  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
-
-  void InitAll(const std::vector<int>& gpus) {
-    comms_.resize(gpus.size());
-    inited_ = false;
-    for (size_t i = 0; i < gpus.size(); ++i) {
-      comm_id_map_[gpus[i]] = i;
-    }
-    PADDLE_ENFORCE(
-        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
-    inited_ = true;
-  }
+  int GetCommId(int device_id) const;

-  ~Communicator() {
-    if (inited_) {
-      for (size_t i = 0; i < comms_.size(); ++i) {
-        // FIXME(dzh) : PADDLE_ENFORCE return void
-        dynload::ncclCommDestroy(comms_[i]);
-      }
-    }
-  }
+  void InitAll(const std::vector<int>& gpus);

-  DISABLE_COPY_AND_ASSIGN(Communicator);
+  const std::vector<ncclComm_t>& comms() const;
 };

 }  // namespace platform
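The header diff is the other half of that move: Communicator keeps only method declarations, while the data members (comms_, comm_id_map_, inited_), the destructor, and DISABLE_COPY_AND_ASSIGN disappear, since instances no longer own any state. A toy version of the same declaration/definition split, where Registry and its members are illustrative names rather than anything from Paddle:

// --- registry.h: declarations only, no data members, no inline logic ---
#include <vector>

struct Registry {
  Registry() {}
  void InitAll(const std::vector<int>& ids);
  const std::vector<int>& ids() const;
};

// --- registry.cc: definitions plus file-local shared state, mirroring
// nccl_gpu_common.cc above ---
#include <memory>

namespace {
std::unique_ptr<std::vector<int>> global_ids;  // shared by every Registry
}  // namespace

void Registry::InitAll(const std::vector<int>& ids) {
  if (global_ids && global_ids->size() == ids.size()) return;  // cached
  global_ids.reset(new std::vector<int>(ids));
}

const std::vector<int>& Registry::ids() const { return *global_ids; }

One consequence of the split is that teardown no longer happens in ~Communicator; destruction of the global communicators is deferred, as the TODO left in the .cc notes.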

paddle/fluid/operators/nccl_op.cu.cc

Lines changed: 5 additions & 5 deletions
@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
           outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
-          comm->comms_[idx], stream));
+          comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));

       VLOG(1) << "gpu : "
@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
       if (root == platform::kInvalidGPUId) {
-        root = hasher(ins_names[i]) % comm->comms_.size();
+        root = hasher(ins_names[i]) % comm->comms().size();
       }
       T* recvbuffer = nullptr;
       if (root == gpu_id) {
@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {

       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
           stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));

@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
       VLOG(1) << " before ncclBcast";
       PADDLE_ENFORCE(platform::dynload::ncclBcast(
           (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
-          root, comm->comms_[idx], stream));
+          root, comm->comms().at(idx), stream));
       VLOG(1) << " after ncclBcast";
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));

@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {

       PADDLE_ENFORCE(platform::dynload::ncclBcast(
           outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
-          NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
+          NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));

       VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
