Skip to content

Commit 62fed4c

Browse files
chengduo authored and dzhwinter committed
fix __shfl_down (#10362)
1 parent 3000e99 commit 62fed4c

File tree

3 files changed

+35
-17
lines changed

3 files changed

+35
-17
lines changed

paddle/cuda/include/hl_base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ extern __thread cudaStream_t default_stream;
229229

230230
// __shfl has been deprecated as of CUDA 9.0.
231231
#if CUDA_VERSION < 9000
232+
template <typename T>
233+
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
234+
return __shfl_down(val, delta);
235+
}
236+
232237
template <typename T>
233238
__forceinline__ __device__ T
234239
__shfl_sync(unsigned, T val, int src_line, int width) {

paddle/fluid/operators/row_conv_op.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
189189
}
190190
__syncthreads();
191191

192+
// NOTE(zcd): temporary solution
193+
unsigned mask = 0u;
194+
CREATE_SHFL_MASK(mask, true);
195+
192196
for (int i = 0; i < num_sequence; i++) {
193197
int start = static_cast<int>(batch_indices[i]);
194198
int end = static_cast<int>(batch_indices[i + 1]);
@@ -220,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
220224

221225
for (int offset = 16; offset > 0;
222226
offset = offset / 2) { // blockDim.x is 32.
223-
val += platform::__shfl_down_sync(0, val, offset);
227+
val += platform::__shfl_down_sync(mask, val, offset);
224228
}
225229
__syncthreads();
226230

@@ -251,6 +255,10 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
251255
T *sh_in = mem;
252256
T *sh_dout = &mem[block_x * block_y];
253257

258+
// NOTE(zcd): temporary solution
259+
unsigned mask = 0u;
260+
CREATE_SHFL_MASK(mask, true);
261+
254262
for (int i = 0; i < num_sequence; i++) {
255263
int start = static_cast<int>(batch_indices[i]);
256264
int end = static_cast<int>(batch_indices[i + 1]);
@@ -276,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
276284

277285
for (int offset = 16; offset > 0;
278286
offset = offset / 2) { // blockDim.x is 32.
279-
val += platform::__shfl_down_sync(0, val, offset);
287+
val += platform::__shfl_down_sync(mask, val, offset);
280288
}
281289
__syncthreads();
282290

paddle/function/RowConvOpGpu.cu

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "RowConvOp.h"
16-
#include "hl_base.h"
15+
#include "paddle/cuda/include/hl_base.h"
16+
#include "paddle/function/RowConvOp.h"
1717

1818
namespace paddle {
1919

@@ -94,7 +94,7 @@ __global__ void KeRowConv2(real* y,
9494
}
9595

9696
template <>
97-
void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
97+
void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out, // NOLINT
9898
const GpuMatrix& in,
9999
const GpuMatrix& filter,
100100
const GpuIVector& seq) {
@@ -144,6 +144,10 @@ __global__ void KeRowConvBwWeight(real* dw,
144144
}
145145
__syncthreads();
146146

147+
// NOTE(zcd): temporary solution
148+
unsigned mask = 0u;
149+
CREATE_SHFL_MASK(mask, true);
150+
147151
for (int i = 0; i < numSeq; ++i) {
148152
const int start = starts[i];
149153
const int end = starts[i + 1];
@@ -170,11 +174,10 @@ __global__ void KeRowConvBwWeight(real* dw,
170174
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
171175
__syncthreads();
172176
// warp size and blockDim.x is 32.
173-
val += __shfl_down(val, 16);
174-
val += __shfl_down(val, 8);
175-
val += __shfl_down(val, 4);
176-
val += __shfl_down(val, 2);
177-
val += __shfl_down(val, 1);
177+
178+
for (int offset = 16; offset > 0; offset /= 2)
179+
val += __shfl_down_sync(mask, val, offset);
180+
178181
__syncthreads();
179182
if (tidx == 0) {
180183
sh_dw[t][tidy] += val;
@@ -205,6 +208,10 @@ __global__ void KeRowConvBwWeight2(real* dw,
205208
__shared__ real sh_x[BLOCK_H][BLOCK_W];
206209
__shared__ real sh_dy[BLOCK_H][BLOCK_W];
207210

211+
// NOTE(zcd): temporary solution
212+
unsigned mask = 0u;
213+
CREATE_SHFL_MASK(mask, true);
214+
208215
for (int i = 0; i < numSeq; ++i) {
209216
const int start = starts[i];
210217
const int end = starts[i + 1];
@@ -230,11 +237,9 @@ __global__ void KeRowConvBwWeight2(real* dw,
230237
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
231238
__syncthreads();
232239
// warp size and blockDim.x is 32.
233-
val += __shfl_down(val, 16);
234-
val += __shfl_down(val, 8);
235-
val += __shfl_down(val, 4);
236-
val += __shfl_down(val, 2);
237-
val += __shfl_down(val, 1);
240+
for (int offset = 16; offset > 0; offset /= 2)
241+
val += __shfl_down_sync(mask, val, offset);
242+
238243
__syncthreads();
239244

240245
if (tidx == 0 && (gidx + tidy) < width) {
@@ -323,8 +328,8 @@ template <>
323328
void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
324329
const GpuMatrix& in,
325330
const GpuMatrix& filter,
326-
GpuMatrix& inG,
327-
GpuMatrix& filterG,
331+
GpuMatrix& inG, // NOLINT
332+
GpuMatrix& filterG, // NOLINT
328333
const GpuIVector& seq) {
329334
const size_t numSeq = seq.getSize() - 1;
330335
const size_t contextLength = filter.getHeight();

0 commit comments

Comments (0)