Commit 92b8810

CUDA: skip masked KV slices for all FA kernels (#14924)
1 parent 00131d6 commit 92b8810

9 files changed, +120 -56 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 14 additions & 0 deletions
@@ -432,6 +432,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
     dst[row] = norm ? sum / ncols : sum;
 }
 
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+#ifdef GGML_USE_HIP
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+    }
+    return x;
+#else
+    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
+    return __all_sync(0xffffffff, x);
+#endif // GGML_USE_HIP
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
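
The new warp_reduce_all helper gives the flash-attention code a portable warp-wide AND: on HIP it uses a shuffle-based butterfly reduction (a 64-lane wavefront makes a hard-coded __all_sync mask incorrect), while on CUDA it maps directly to __all_sync. Below is a minimal, self-contained sketch of how such a reduction is typically used; the kernel and its names are illustrative, not part of this commit.

// Illustrative CUDA-only sketch: each warp checks whether its 32 consecutive mask
// values are all infinite, and lane 0 writes one flag per 32-element chunk. This
// mirrors how warp_reduce_all(all_inf) is used to detect fully masked KV tiles.
// Launch with blockDim.x a multiple of 32.
#include <cuda_runtime.h>

static __device__ __forceinline__ int warp_all(int x) {
    return __all_sync(0xffffffff, x); // the CUDA path of warp_reduce_all
}

__global__ void chunk_all_inf(const float * mask, int * out, const int n) {
    const int i    = blockIdx.x*blockDim.x + threadIdx.x;
    const int lane = threadIdx.x % 32;

    const int is_inf = i < n ? (int) isinf(mask[i]) : 1; // out-of-range lanes don't veto
    const int all    = warp_all(is_inf);

    if (lane == 0) {
        out[i/32] = all; // 1 iff the whole 32-element chunk is +/- infinity
    }
}

On HIP the same call site would go through the shuffle loop instead, so the helper keeps the kernels identical across vendors.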

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 73 additions & 2 deletions
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31 = gridDim.x;
+    const int tid = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            KV_max_sj += FATTN_KQ_STRIDE;
+            break;
+        }
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
@@ -711,6 +761,7 @@ void launch_fattn(
 
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
+    ggml_cuda_pool_alloc<int> KV_max(pool);
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -779,11 +830,30 @@
         V_data = (char *) V_f16.ptr;
     }
 
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    int parallel_blocks = 1;
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -870,6 +940,7 @@
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
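
flash_attn_mask_to_KV_max is launched once per flash-attention call (when a mask is present and the batch is large enough) and writes one entry per (sequence, Q tile): the exclusive KV bound, in token positions and rounded to FATTN_KQ_STRIDE, beyond which every mask value is -inf. The kernels then read it as KV_max[sequence*n_tiles + tile] and stop their KV loops at that bound. Below is a simplified CPU reference of the same scan, under assumptions not taken from the commit (float mask, contiguous layout, hypothetical helper name).

// Reference-only sketch of what flash_attn_mask_to_KV_max computes for one
// (sequence, Q tile): scan KV strides from the top down and return the exclusive
// end of the last stride that still contains a finite mask value. Everything at
// or beyond the returned position is fully masked and can be skipped.
#include <cmath>

static int kv_max_for_tile(const float * mask, const int n_kv, const int ncols1,
                           const int stride_row, const int kq_stride) {
    int kv_max = (n_kv / kq_stride) * kq_stride; // n_kv assumed to be a multiple of kq_stride
    while (kv_max > 0) {
        bool all_inf = true;
        for (int j = 0; j < ncols1 && all_inf; ++j) {
            for (int k = kv_max - kq_stride; k < kv_max && all_inf; ++k) {
                all_inf = std::isinf(mask[j*stride_row + k]);
            }
        }
        if (!all_inf) {
            break; // this stride still has unmasked entries, keep it in the loop bound
        }
        kv_max -= kq_stride; // whole stride is masked, drop it
    }
    return kv_max;
}

The device kernel does the same work in parallel: FATTN_KQ_STRIDE/2 threads test half2 mask values, warp_reduce_all plus a small shared-memory buffer combine the per-thread results, and only thread 0 writes the bound.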

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 16 additions & 5 deletions
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
     }
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
+         bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2 * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 
     // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+    int kb0 = kb0_start;
+    for (; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -1280,7 +1283,11 @@
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel = kb0_stop * kb_niter;
+        int kb0_stop_kernel = kb0_stop * kb_niter;
+
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
@@ -1321,7 +1328,11 @@
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel = kb0_stop * kb_niter;
+        int kb0_stop_kernel = kb0_stop * kb_niter;
+
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
         constexpr bool needs_fixup = false;
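
In the mma kernel the bound stored in KV_max is expressed in KV token positions, while kb0_stop_kernel counts iterations of c::nbatch_fa tokens, hence the division before the min. A small sketch of that unit conversion follows; the constants are example values only, not the kernel's actual configuration.

// Illustrative only: clamp an iteration bound using a token-position bound.
// KV_max values are multiples of FATTN_KQ_STRIDE, and nbatch_fa is assumed to
// divide FATTN_KQ_STRIDE, so the division below is exact.
constexpr int FATTN_KQ_STRIDE = 256; // example value
constexpr int nbatch_fa       = 64;  // example value

static int clamp_kb0_stop(const int kb0_stop_kernel, const int kv_max_tokens) {
    const int kb0_bound = kv_max_tokens / nbatch_fa; // tokens -> nbatch_fa-sized iterations
    return kb0_stop_kernel < kb0_bound ? kb0_stop_kernel : kb0_bound;
}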

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 3 additions & 1 deletion
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
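
The tile kernels (this file and fattn-tile-f32.cu) as well as the vec and wmma kernels below all apply the same pattern: read the per-tile bound once, fall back to ne11 when no KV_max buffer was produced, and use it as the KV loop limit. A stripped-down skeleton of that loop structure follows; the stride constant and kernel name are placeholders.

// Skeleton of the clamped KV loop. KV_STRIDE stands in for the per-kernel chunk
// size (FATTN_KQ_STRIDE_TILE_F16, FATTN_KQ_STRIDE, D, ...); sequence comes from the
// block index in the real kernels and is passed explicitly here only to keep the
// sketch short.
#define KV_STRIDE 64

__global__ void clamped_kv_loop(const int * KV_max, const int ne11, const int sequence) {
    // One bound per (sequence, Q tile); blockIdx.x/gridDim.x index the Q tiles.
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;

    for (int k_VKQ_0 = blockIdx.y*KV_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*KV_STRIDE) {
        // ... process the KV chunk [k_VKQ_0, k_VKQ_0 + KV_STRIDE) ...
        // Fully masked trailing chunks are never visited because k_VKQ_max excludes them.
    }
}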

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 3 additions & 1 deletion
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 3 additions & 23 deletions
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K += blockIdx.y*D * nb11;
     V += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
              // Increment pointers after each loop:
              K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -191,29 +193,7 @@
         for (int j = 0; j < ncols; ++j) {
             maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
         }
-
         __syncthreads();
-
-        // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-        // In such cases, skip the KV slice.
-        // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-        bool skip = true;
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
-                skip = skip && isinf(tmp.x) && isinf(tmp.y);
-            }
-        }
-        if (__all_sync(0xFFFFFFFF, skip)) {
-            __syncthreads();
-            continue;
-        }
-#endif // GGML_USE_HIP
     }
 
     // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 3 additions & 22 deletions
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -183,10 +184,11 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    K += blockIdx.y*D * nb11;
     V += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
             // Increment pointers after each loop:
             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -197,28 +199,7 @@
         for (int j = 0; j < ncols; ++j) {
             maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
         }
-
         __syncthreads();
-
-        // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-        // In such cases, skip the KV slice.
-        // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-        bool skip = true;
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-            for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                skip = skip && isinf(maskf_shared[j*D + i]);
-            }
-        }
-        if (__all_sync(0xFFFFFFFF, skip)) {
-            __syncthreads();
-            continue;
-        }
-#endif // GGML_USE_HIP
     }
 
     float kqmax_new_arr[ncols];

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 3 additions & 1 deletion
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -165,7 +166,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 1 deletion
@@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
+        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
