@@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
43
43
void Compute (const framework::ExecutionContext& ctx) const override {
44
44
PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
45
45
" This kernel only runs on GPU device." );
46
-
47
- auto ins = ctx.MultiInput <LoDTensor>(" X" );
48
- auto outs = ctx.MultiOutput <LoDTensor>(" Out" );
49
-
46
+ auto * x = ctx.Input <LoDTensor>(" X" );
47
+ auto * out = ctx.Output <LoDTensor>(" Out" );
48
+ auto * comm = ctx.Input <Communicator>(" Communicator" );
50
49
std::string reduction = ctx.Attr <std::string>(" reduction" );
51
- ncclRedOp_t reduction_op_ = ncclSum;
52
50
51
+ ncclRedOp_t reduction_op_ = ncclSum;
53
52
if (reduction == " ncclMin" ) {
54
53
reduction_op_ = ncclMin;
55
54
} else if (reduction == " ncclMax" ) {
@@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
61
60
} else {
62
61
PADDLE_THROW (" Invalid reduction. default ncclSum." );
63
62
}
64
-
65
- auto * comm = ctx.Input <Communicator>(" Communicator" );
66
-
67
- auto stream = ctx.cuda_device_context ().stream ();
68
-
69
63
// device id
70
64
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace ()).GetDeviceId ();
71
65
int idx = comm->GetCommId (gpu_id);
72
-
73
- for (size_t i = 0 ; i < ins.size (); ++i) {
74
- VLOG (1 ) << " gpu : "
75
- << " invoke allreduce. send " << ins[i]->numel () << " recv "
76
- << outs[i]->numel ();
77
-
78
- PADDLE_ENFORCE (platform::dynload::ncclAllReduce (
79
- ins[i]->data <T>(), outs[i]->mutable_data <T>(ctx.GetPlace ()),
80
- outs[i]->numel (), NCCLTypeWrapper<T>::type, reduction_op_,
81
- comm->comms ().at (idx), stream));
82
- PADDLE_ENFORCE (cudaStreamSynchronize (stream));
83
-
84
- VLOG (1 ) << " gpu : "
85
- << " finished allreduce. send " << ins[i]->numel () << " recv "
86
- << outs[i]->numel ();
87
- }
66
+ VLOG (3 ) << " gpu : "
67
+ << " invoke allreduce. send " << x->numel () << " recv "
68
+ << out->numel ();
69
+ PADDLE_ENFORCE (platform::dynload::ncclAllReduce (
70
+ x->data <T>(), out->mutable_data <T>(ctx.GetPlace ()), out->numel (),
71
+ NCCLTypeWrapper<T>::type, reduction_op_, comm->comms ().at (idx),
72
+ ctx.cuda_device_context ().stream ()));
73
+ VLOG (3 ) << " gpu : "
74
+ << " finished allreduce. send " << x->numel () << " recv "
75
+ << out->numel ();
88
76
}
89
77
};
90
78
@@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
94
82
void Compute (const framework::ExecutionContext& ctx) const override {
95
83
PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
96
84
" This kernel only runs on GPU device." );
97
-
98
- auto ins = ctx.MultiInput <LoDTensor>(" X " ); // x0, x1, x2
99
- auto outs = ctx.MultiOutput <LoDTensor >(" Out " );
100
-
85
+ auto x = ctx. Input <LoDTensor>( " X " ); // x0, x1, x2
86
+ auto out = ctx.Output <LoDTensor>(" Out " );
87
+ auto * comm = ctx.Input <Communicator >(" Communicator " );
88
+ int root = ctx. Attr < int >( " root " );
101
89
std::string reduction = ctx.Attr <std::string>(" reduction" );
102
- ncclRedOp_t reduction_op_ = ncclSum;
103
90
91
+ ncclRedOp_t reduction_op_ = ncclSum;
104
92
if (reduction == " ncclMin" ) {
105
93
reduction_op_ = ncclMin;
106
94
} else if (reduction == " ncclMax" ) {
@@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
112
100
} else {
113
101
PADDLE_THROW (" Invalid reduction. default ncclSum." );
114
102
}
115
-
116
- int root = ctx.Attr <int >(" root" );
117
- auto * comm = ctx.Input <Communicator>(" Communicator" );
118
-
119
- auto stream = reinterpret_cast <const platform::CUDADeviceContext&>(
120
- ctx.device_context ())
121
- .stream ();
122
103
// device id
123
104
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace ()).GetDeviceId ();
124
105
int idx = comm->GetCommId (gpu_id);
125
-
126
- auto ins_names = ctx.Inputs (" X" );
127
- std::hash<std::string> hasher;
128
- for (size_t i = 0 ; i < ins.size (); ++i) {
129
- if (root == platform::kInvalidGPUId ) {
130
- root = hasher (ins_names[i]) % comm->comms ().size ();
131
- }
132
- T* recvbuffer = nullptr ;
133
- if (root == gpu_id) {
134
- recvbuffer = outs[i]->mutable_data <T>(ctx.GetPlace ());
135
- }
136
-
137
- VLOG (1 ) << " gpu : " << gpu_id << " invoke reduce. send "
138
- << ins[i]->numel () << " recv " << outs[i]->numel ();
139
-
140
- PADDLE_ENFORCE (platform::dynload::ncclReduce (
141
- ins[i]->data <T>(), recvbuffer, ins[i]->numel (),
142
- NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms ().at (idx),
143
- stream));
144
- PADDLE_ENFORCE (cudaStreamSynchronize (stream));
145
-
146
- VLOG (1 ) << " gpu : " << gpu_id << " finished reduce. send "
147
- << ins[i]->numel () << " recv " << outs[i]->numel ();
106
+ T* recvbuffer = nullptr ;
107
+ if (root == gpu_id) {
108
+ recvbuffer = out->mutable_data <T>(ctx.GetPlace ());
148
109
}
110
+ VLOG (3 ) << " gpu : " << gpu_id << " invoke reduce. send " << x->numel ()
111
+ << " recv " << out->numel ();
112
+ PADDLE_ENFORCE (platform::dynload::ncclReduce (
113
+ x->data <T>(), recvbuffer, x->numel (), NCCLTypeWrapper<T>::type,
114
+ reduction_op_, root, comm->comms ().at (idx),
115
+ ctx.cuda_device_context ().stream ()));
116
+ VLOG (3 ) << " gpu : " << gpu_id << " finished reduce. send " << x->numel ()
117
+ << " recv " << out->numel ();
149
118
}
150
119
};
151
120
@@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
155
124
void Compute (const framework::ExecutionContext& ctx) const override {
156
125
PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
157
126
" This kernel only runs on GPU device." );
158
-
159
127
int root = ctx.Attr <int >(" root" );
160
-
161
128
auto * comm = ctx.Input <Communicator>(" Communicator" );
162
-
163
- auto stream = reinterpret_cast <const platform::CUDADeviceContext&>(
164
- ctx.device_context ())
165
- .stream ();
166
129
// device id
167
130
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace ()).GetDeviceId ();
168
131
int idx = comm->GetCommId (gpu_id);
169
-
170
132
if (idx == root) {
171
- auto ins = ctx.MultiInput <LoDTensor>(" X" );
172
- for (size_t i = 0 ; i < ins.size (); ++i) {
173
- VLOG (1 ) << " gpu : " << gpu_id << " invoke Bcast. send "
174
- << ins[i]->numel ();
175
-
176
- VLOG (1 ) << " before ncclBcast" ;
177
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
178
- (void *)ins[i]->data <T>(), ins[i]->numel (), NCCLTypeWrapper<T>::type,
179
- root, comm->comms ().at (idx), stream));
180
- VLOG (1 ) << " after ncclBcast" ;
181
- PADDLE_ENFORCE (cudaStreamSynchronize (stream));
182
-
183
- VLOG (1 ) << " gpu : " << gpu_id << " finished Bcast." ;
184
- }
133
+ auto * x = ctx.Input <LoDTensor>(" X" );
134
+ VLOG (3 ) << " gpu : " << gpu_id << " invoke Bcast. send " << x->numel ();
135
+ PADDLE_ENFORCE (platform::dynload::ncclBcast (
136
+ (void *)x->data <T>(), x->numel (), NCCLTypeWrapper<T>::type, root,
137
+ comm->comms ().at (idx), ctx.cuda_device_context ().stream ()));
138
+ VLOG (3 ) << " gpu : " << gpu_id << " finished Bcast." ;
185
139
} else {
186
- auto outs = ctx.MultiOutput <LoDTensor>(" Out" );
187
- for (size_t i = 0 ; i < outs.size (); ++i) {
188
- VLOG (1 ) << " gpu : " << gpu_id << " invoke Bcast. recv buffer "
189
- << framework::product (outs[i]->dims ());
190
-
191
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
192
- outs[i]->mutable_data <T>(ctx.GetPlace ()), outs[i]->numel (),
193
- NCCLTypeWrapper<T>::type, root, comm->comms ().at (idx), stream));
194
- PADDLE_ENFORCE (cudaStreamSynchronize (stream));
195
-
196
- VLOG (1 ) << " gpu : " << gpu_id << " finished Bcast. recv "
197
- << outs[i]->numel ();
198
- }
140
+ auto * out = ctx.Output <LoDTensor>(" Out" );
141
+ VLOG (3 ) << " gpu : " << gpu_id << " invoke Bcast. recv buffer "
142
+ << framework::product (out->dims ());
143
+ PADDLE_ENFORCE (platform::dynload::ncclBcast (
144
+ out->mutable_data <T>(ctx.GetPlace ()), out->numel (),
145
+ NCCLTypeWrapper<T>::type, root, comm->comms ().at (idx),
146
+ ctx.cuda_device_context ().stream ()));
147
+ VLOG (3 ) << " gpu : " << gpu_id << " finished Bcast. recv " << out->numel ();
199
148
}
200
149
}
201
150
};
0 commit comments