Skip to content

Commit c3c7b7b

Browse files
authored
Merge pull request #9928 from reyoung/feature/stablize_code
Use mutex to stabilize ncclCtxMap
2 parents 35483a2 + 093d227 commit c3c7b7b

File tree

1 file changed

+16
-34
lines changed

1 file changed

+16
-34
lines changed

paddle/fluid/platform/nccl_helper.h

Lines changed: 16 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
3939

4040
class NCCLGroupGuard {
4141
public:
42+
static std::mutex &NCCLMutex() {
43+
static std::mutex mtx;
44+
return mtx;
45+
}
46+
4247
inline NCCLGroupGuard() {
43-
mutex().lock();
48+
NCCLMutex().lock();
4449
PADDLE_ENFORCE(dynload::ncclGroupStart());
4550
}
4651

4752
inline ~NCCLGroupGuard() {
4853
PADDLE_ENFORCE(dynload::ncclGroupEnd());
49-
mutex().unlock();
50-
}
51-
52-
private:
53-
static std::mutex &mutex() {
54-
static std::mutex mtx;
55-
return mtx;
54+
NCCLMutex().unlock();
5655
}
5756
};
5857

@@ -68,26 +67,6 @@ struct NCCLContext {
6867
int device_id() const {
6968
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
7069
}
71-
72-
static void InitNCCLContext(std::unordered_map<int, NCCLContext> *contexts,
73-
const std::vector<platform::Place> &places) {
74-
std::vector<ncclComm_t> comms;
75-
std::vector<int> devs;
76-
comms.resize(contexts->size());
77-
devs.reserve(contexts->size());
78-
79-
for (auto &p : places) {
80-
devs.push_back(boost::get<platform::CUDAPlace>(p).device);
81-
}
82-
83-
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
84-
&comms[0], static_cast<int>(contexts->size()), &devs[0]));
85-
86-
int i = 0;
87-
for (auto &dev_id : devs) {
88-
contexts->at(dev_id).comm_ = comms[i++];
89-
}
90-
}
9170
};
9271

9372
struct NCCLContextMap {
@@ -107,19 +86,22 @@ struct NCCLContextMap {
10786
"NCCL Context Map does not support contain two or more same device");
10887

10988
if (places.size() > 1) {
110-
std::vector<ncclComm_t> comms;
111-
comms.resize(order_.size());
112-
113-
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
114-
&comms[0], static_cast<int>(order_.size()), &order_[0]));
115-
89+
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
90+
{
91+
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
92+
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
93+
comms.get(), static_cast<int>(order_.size()), order_.data()));
94+
}
11695
int i = 0;
11796
for (auto &dev_id : order_) {
11897
contexts_.at(dev_id).comm_ = comms[i++];
11998
}
12099
}
121100
}
122101

102+
NCCLContextMap(const NCCLContextMap &other) = delete;
103+
NCCLContextMap &operator=(const NCCLContextMap &other) = delete;
104+
123105
CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
124106

125107
CUDADeviceContext *DevCtx(platform::Place p) const {

0 commit comments

Comments
 (0)