@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
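The warp reduction now takes the participation mask as an argument instead of hard-coding FINAL_MASK (all 32 lanes). A minimal standalone sketch of the same XOR-butterfly pattern, assuming HALF_WARP is 16 and warpSize is 32 as elsewhere in this file (not the Paddle kernel itself):

// Sketch only: sum across the lanes named by lane_mask. With
// lane_mask = 0xffffffffu all 32 lanes take part; a narrower mask must
// still name every lane that actually executes this call.
__device__ float warp_sum_sketch(float val, unsigned lane_mask) {
  for (int offset = 16; offset > 0; offset >>= 1)  // 16 == HALF_WARP
    val += __shfl_xor_sync(lane_mask, val, offset, 32);
  return val;  // every participating lane ends up with the full sum
}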
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceSum<T>(val);
+  val = warpReduceSum<T>(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val, mask);
 
   return val;
 }
 
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);
 
   return val;
 }
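Both block-level reductions previously counted the per-warp partial results as blockDim.x >> 5, which floors and drops the trailing partial warp when the block size is not a multiple of 32; block_span rounds up instead. A small illustrative check of the difference (the block sizes are made up, not taken from the operator):

#include <cstdio>

// Illustrative only: number of warps holding a partial result, floored
// (old behaviour) vs. rounded up to warpSize (new block_span).
int main() {
  const int warp_size = 32;
  int block_dims[] = {64, 80, 96};  // 80 is not a multiple of 32
  for (int block_dim : block_dims) {
    int floored = block_dim >> 5;                      // old: misses the partial warp
    int rounded = (block_dim + warp_size - 1) >> 5;    // new: counts it
    std::printf("blockDim.x=%d  floor=%d  ceil=%d\n", block_dim, floored, rounded);
  }
  return 0;
}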
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                        bias_qk_[threadIdx.x + bias_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+
+  float max_val = blockReduceMax<float>(tmp, mask);
+
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();
 
   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);
 
   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;
 
+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
 
 template <typename T>
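The launch site now derives the shuffle mask from the block size: a block narrower than one warp gets a mask covering only its lanes, anything wider uses the full-warp mask. A standalone sketch of the same computation, assuming FINAL_MASK is 0xffffffff as defined earlier in this file (helper name is hypothetical, not part of the operator):

#include <cassert>

// Illustrative helper: build the lane mask the softmax kernel is launched with.
static unsigned LaneMaskForBlock(int block) {
  const unsigned kFinalMask = 0xffffffffu;  // FINAL_MASK in the source
  return block < 32 ? ((1u << block) - 1u) : kFinalMask;
}

int main() {
  assert(LaneMaskForBlock(8) == 0xffu);          // 8 threads -> low 8 lanes set
  assert(LaneMaskForBlock(32) == 0xffffffffu);   // full warp
  assert(LaneMaskForBlock(128) == 0xffffffffu);  // multiple warps, full-warp mask
  return 0;
}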
@@ -331,7 +339,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
   auto stream = dev_ctx.stream();
 
   int grid = m;
-  PADDLE_ENFORCE_LT(k, 1024,
+  PADDLE_ENFORCE_LE(k, 1024,
                     "Input head_number * size_per_head should <= 1024");
   int block = k <= 1024 ? k : 1024;
   add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,