Skip to content

Commit c64190e

Browse files
committed
Polish NCCLHelper
1 parent 7483555 commit c64190e

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

paddle/fluid/platform/nccl_helper.h

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ struct NCCLContext {
6161
ncclComm_t comm_;
6262

6363
explicit NCCLContext(int dev_id)
64-
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {}
64+
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
6565

6666
cudaStream_t stream() const { return ctx_->stream(); }
6767

@@ -95,6 +95,7 @@ struct NCCLContextMap {
9595
std::vector<int> order_;
9696

9797
explicit NCCLContextMap(const std::vector<platform::Place> &places) {
98+
PADDLE_ENFORCE(!places.empty());
9899
order_.reserve(places.size());
99100
for (auto &p : places) {
100101
int dev_id = boost::get<CUDAPlace>(p).device;
@@ -105,15 +106,17 @@ struct NCCLContextMap {
105106
order_.size(), contexts_.size(),
106107
"NCCL Context Map does not support contain two or more same device");
107108

108-
std::vector<ncclComm_t> comms;
109-
comms.resize(order_.size());
109+
if (places.size() > 1) {
110+
std::vector<ncclComm_t> comms;
111+
comms.resize(order_.size());
110112

111-
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
112-
&comms[0], static_cast<int>(order_.size()), &order_[0]));
113+
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
114+
&comms[0], static_cast<int>(order_.size()), &order_[0]));
113115

114-
int i = 0;
115-
for (auto &dev_id : order_) {
116-
contexts_.at(dev_id).comm_ = comms[i++];
116+
int i = 0;
117+
for (auto &dev_id : order_) {
118+
contexts_.at(dev_id).comm_ = comms[i++];
119+
}
117120
}
118121
}
119122

0 commit comments

Comments
 (0)