@@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
39
39
40
40
class NCCLGroupGuard {
41
41
public:
42
+ static std::mutex &NCCLMutex () {
43
+ static std::mutex mtx;
44
+ return mtx;
45
+ }
46
+
42
47
inline NCCLGroupGuard () {
43
- mutex ().lock ();
48
+ NCCLMutex ().lock ();
44
49
PADDLE_ENFORCE (dynload::ncclGroupStart ());
45
50
}
46
51
47
52
inline ~NCCLGroupGuard () {
48
53
PADDLE_ENFORCE (dynload::ncclGroupEnd ());
49
- mutex ().unlock ();
50
- }
51
-
52
- private:
53
- static std::mutex &mutex () {
54
- static std::mutex mtx;
55
- return mtx;
54
+ NCCLMutex ().unlock ();
56
55
}
57
56
};
58
57
@@ -68,26 +67,6 @@ struct NCCLContext {
68
67
int device_id () const {
69
68
return boost::get<platform::CUDAPlace>(ctx_->GetPlace ()).device ;
70
69
}
71
-
72
- static void InitNCCLContext (std::unordered_map<int , NCCLContext> *contexts,
73
- const std::vector<platform::Place> &places) {
74
- std::vector<ncclComm_t> comms;
75
- std::vector<int > devs;
76
- comms.resize (contexts->size ());
77
- devs.reserve (contexts->size ());
78
-
79
- for (auto &p : places) {
80
- devs.push_back (boost::get<platform::CUDAPlace>(p).device );
81
- }
82
-
83
- PADDLE_ENFORCE (platform::dynload::ncclCommInitAll (
84
- &comms[0 ], static_cast <int >(contexts->size ()), &devs[0 ]));
85
-
86
- int i = 0 ;
87
- for (auto &dev_id : devs) {
88
- contexts->at (dev_id).comm_ = comms[i++];
89
- }
90
- }
91
70
};
92
71
93
72
struct NCCLContextMap {
@@ -107,19 +86,22 @@ struct NCCLContextMap {
107
86
" NCCL Context Map does not support contain two or more same device" );
108
87
109
88
if (places.size () > 1 ) {
110
- std::vector <ncclComm_t> comms;
111
- comms. resize (order_. size ());
112
-
113
- PADDLE_ENFORCE (platform::dynload::ncclCommInitAll (
114
- & comms[ 0 ] , static_cast <int >(order_.size ()), & order_[ 0 ] ));
115
-
89
+ std::unique_ptr <ncclComm_t[] > comms ( new ncclComm_t[order_. size ()]) ;
90
+ {
91
+ std::lock_guard<std::mutex> guard ( NCCLGroupGuard::NCCLMutex ());
92
+ PADDLE_ENFORCE (platform::dynload::ncclCommInitAll (
93
+ comms. get () , static_cast <int >(order_.size ()), order_. data () ));
94
+ }
116
95
int i = 0 ;
117
96
for (auto &dev_id : order_) {
118
97
contexts_.at (dev_id).comm_ = comms[i++];
119
98
}
120
99
}
121
100
}
122
101
102
+ NCCLContextMap (const NCCLContextMap &other) = delete ;
103
+ NCCLContextMap &operator =(const NCCLContextMap &other) = delete ;
104
+
123
105
CUDADeviceContext *DevCtx (int dev_id) const { return at (dev_id).ctx_ .get (); }
124
106
125
107
CUDADeviceContext *DevCtx (platform::Place p) const {
0 commit comments