
Commit 2f0f10b

[cherry-pick] Fix multihead op bug. (#20783) (#21438)
The op should handle k = 1024. Fix the seq_len < warpSize error.

test=develop
Signed-off-by: zhaoyuchen <[email protected]>
1 parent 873b32d commit 2f0f10b
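
The seq_len < warpSize half of the fix comes down to how the `__shfl_*_sync` intrinsics treat their mask: the mask must name exactly the lanes that reach the call, and when the softmax kernel is launched with block = seq_len < 32 threads, the old full mask FINAL_MASK (0xffffffff) named lanes that were never launched, which is undefined behaviour. The sketch below is a minimal, self-contained illustration of that rule, not code from this commit; the kernel broadcast_lane0, its launch configuration, and the printed check are made up for the demo, while the mask construction mirrors the one added in MatMulWithHeadQK.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical demo kernel: every active lane reads lane 0's value. The mask
// passed to __shfl_sync must list exactly the lanes that execute this call.
__global__ void broadcast_lane0(float *out, unsigned lane_mask) {
  float val = static_cast<float>(threadIdx.x);
  val = __shfl_sync(lane_mask, val, 0, 32);
  out[threadIdx.x] = val;
}

int main() {
  const int block = 20;  // e.g. seq_len smaller than warpSize (32)
  // Partial mask covering lanes 0..block-1; full mask once block >= 32.
  unsigned mask = block < 32 ? ((1u << block) - 1) : 0xffffffffu;

  float *d_out = nullptr;
  cudaMalloc(&d_out, block * sizeof(float));
  broadcast_lane0<<<1, block>>>(d_out, mask);

  float h_out[block];
  cudaMemcpy(h_out, d_out, block * sizeof(float), cudaMemcpyDeviceToHost);
  printf("lane 5 sees %.0f (expected 0, lane 0's value)\n", h_out[5]);
  cudaFree(d_out);
  return 0;
}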

File tree

2 files changed: +26 -18 lines changed


paddle/fluid/operators/multihead_matmul_op.cc
Lines changed: 1 addition & 1 deletion

@@ -134,7 +134,7 @@ MultiHeadMatMul Operator.
 This op is used for optimize multi head calculation in ernie model.
 Not suggest to use in other case except has same structure as ernie.
 
-Example of matrix multiplication with head_number of H
+Example of matrix multiplication with head_number of B
 - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
 
 Both the input `Q` and `K` can carry the LoD (Level of Details) information,

paddle/fluid/operators/multihead_matmul_op.cu
Lines changed: 25 additions & 17 deletions
@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceSum<T>(val);
+  val = warpReduceSum<T>(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val, mask);
 
   return val;
 }
 
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);
 
   return val;
 }
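
The other half of the fix in these two reductions is the block_span rounding. The old guard `threadIdx.x < (blockDim.x >> 5)` truncates the warp count, so for a block of fewer than 32 threads it becomes `threadIdx.x < 0`: no thread, not even lane 0, re-reads the per-warp partial stored in shared memory, the second warp reduction runs on the fill value (0 for the sum, -1e10f for the max), and the block-wide result is lost. Rounding up to the number of warps actually present keeps those partials in play, and it also stops the last, partially filled warp of a block whose size is not a multiple of 32 from being dropped. A tiny host-side sketch of the arithmetic, purely illustrative:

#include <cstdio>

int main() {
  const int kWarpSize = 32;
  const int block_dims[] = {20, 32, 100, 1024};
  for (int block_dim : block_dims) {
    int old_span = block_dim >> 5;                    // truncates toward zero
    int new_span = (block_dim + kWarpSize - 1) >> 5;  // rounds up to whole warps
    printf("blockDim.x = %4d  old span = %2d  new span = %2d\n",
           block_dim, old_span, new_span);
  }
  return 0;
}
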
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                        bias_qk_[threadIdx.x + bias_offset]))
                 : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+
+  float max_val = blockReduceMax<float>(tmp, mask);
+
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();
 
   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);
 
   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
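
Stepping back from the plumbing: softmax_kernel_with_eltadd computes a numerically stabilised softmax over each row of seq_len attention scores, and the two block reductions patched above supply the row maximum and the row sum for it. The sketch below is a plain host-side reference of that math, written for illustration only; the final divide-by-s_sum step is implied by this hunk but not shown in it.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Host-side reference (illustrative, not the CUDA code) of one row of the
// kernel: add bias, subtract the row max before exponentiating, and guard
// the divisor with 1e-6f.
std::vector<float> softmax_row(const std::vector<float> &qk,
                               const std::vector<float> &bias) {
  float max_val = -1e20f;
  for (size_t i = 0; i < qk.size(); ++i)
    max_val = std::max(max_val, qk[i] + bias[i]);  // role of blockReduceMax
  std::vector<float> out(qk.size());
  float sum = 0.0f;
  for (size_t i = 0; i < qk.size(); ++i) {
    out[i] = std::exp(qk[i] + bias[i] - max_val);  // __expf(tmp - s_max)
    sum += out[i];                                 // role of blockReduceSum
  }
  for (float &v : out) v /= (sum + 1e-6f);         // s_sum = sum_val + 1e-6f
  return out;
}

int main() {
  std::vector<float> row = softmax_row({1.0f, 2.0f, 3.0f}, {0.0f, 0.0f, 0.0f});
  for (float v : row) printf("%f ", v);  // approx. 0.09 0.24 0.67
  printf("\n");
  return 0;
}
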
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;
 
+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
 
 template <typename T>
@@ -331,7 +339,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
   auto stream = dev_ctx.stream();
 
   int grid = m;
-  PADDLE_ENFORCE_LT(k, 1024,
+  PADDLE_ENFORCE_LE(k, 1024,
                     "Input head_number * size_per_head should <= 1024");
   int block = k <= 1024 ? k : 1024;
   add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
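
The last hunk is the k = 1024 half of the commit message: k (head_number * size_per_head) is used directly as the thread-block size, and 1024 is exactly the per-block thread limit on CUDA GPUs of compute capability 2.0 and newer, so the boundary value is a legal launch size that PADDLE_ENFORCE_LT wrongly rejected. A small illustrative check of that limit, not part of the patch:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // 1024 on every GPU of compute capability 2.0 or newer.
  printf("maxThreadsPerBlock = %d\n", prop.maxThreadsPerBlock);

  const int k = 1024;  // head_number * size_per_head at the boundary
  printf("block = %d is %s\n", k,
         k <= prop.maxThreadsPerBlock ? "a legal launch size"
                                      : "too large for one block");
  return 0;
}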
