Commit 34c9d76

CUDA: add attention sinks for tile and wmma (#15178)
* CUDA: add attention sinks for tile and wmma
* Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma
1 parent e54d41b commit 34c9d76
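
Background on the change (not part of the commit itself): an attention sink is a per-head logit that takes part in the softmax normalization but contributes no value vector. All three kernels fold it in the same way after the KV loop: raise the running maximum if the sink logit exceeds it, rescale the accumulated denominator and the VKQ accumulator by exp(old_max - new_max), and add exp(sink - new_max) to the denominator. A minimal scalar sketch of that fold-in, mirroring the float path of the tile kernels (the function and variable names below are illustrative, not taken from the kernels):

// Illustrative sketch only: per-row sink fold-in at the end of a streaming softmax.
// kq_max/kq_sum are the running max and denominator, vkq is the value accumulator
// of length d for one query row, sink is the per-head sink logit.
__device__ void apply_attention_sink(float sink, float & kq_max, float & kq_sum, float * vkq, int d) {
    const float kq_max_new = fmaxf(kq_max, sink);        // the sink may raise the running max
    const float scale      = expf(kq_max - kq_max_new);  // rescale everything accumulated so far
    kq_max = kq_max_new;

    kq_sum = kq_sum*scale + expf(sink - kq_max);         // the sink only enlarges the denominator
    for (int i = 0; i < d; ++i) {
        vkq[i] *= scale;                                 // it contributes no value vector
    }
}

In the kernels this runs only when blockIdx.y == 0, presumably because blockIdx.y indexes the split over the KV sequence and the partial results are combined afterwards, so the sink must be counted exactly once per head.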

File tree (4 files changed: +115 −24 lines)

* ggml/src/ggml-cuda/fattn-tile-f16.cu
* ggml/src/ggml-cuda/fattn-tile-f32.cu
* ggml/src/ggml-cuda/fattn-wmma-f16.cu
* ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 30 additions & 4 deletions

@@ -49,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -242,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 32 additions & 4 deletions

@@ -60,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -252,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
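
The f32 variant applies the same per-head adjustment as the f16 tile kernel; the visible differences are precision-related: kqsum is a plain float here, so the sink term expf(sink - max) is added directly rather than via __hadd on the .x lane of a half2, and the float2 accumulator is rescaled one component at a time.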

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 53 additions & 5 deletions

@@ -82,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const half2 * mask2 = (const half2 *) maskh;
+    const float * Q_f    = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half  * K_h    = (const half  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half  * V_h    = (const half  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half2 * mask2  = (const half2 *) maskh;
+    const float * sinksf = (const float *) sinks;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -381,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    // Apply attention sinks
+    if (sinksf && blockIdx.y == 0) {
+        const float sinkf = sinksf[head];
+        const half  sinkh = __float2half(sinkf);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+                KQ_max_f[j0/nwarps] = kqmax_new;
+
+                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
+                }
+            } else {
+                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+                half kqmax_new = fmaxf(kqmax_old, sinkh);
+                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);
+
+                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+                const half val = hexp(sinkh - kqmax_new);
+                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+                }
+            }
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
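
A note on the fp16 branch (my reading of the surrounding kernel, not something stated in the commit): KQ_max_h2 and KQ_rowsum_h2 keep their per-row state packed in a half2, and the two lanes of the row sum appear to be added together when the final normalization is computed, so the sink term hexp(sinkh - kqmax_new) goes into the .x lane only; adding it to both lanes would count it twice. The trailing __syncthreads() keeps the rescaling of the VKQ2 accumulator ordered before the write-out loop that follows.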

ggml/src/ggml-cuda/fattn.cu

Lines changed: 0 additions & 11 deletions

@@ -274,23 +274,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K    = dst->src[1];
     const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
-    const ggml_tensor * sinks = dst->src[4];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks && !fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
-        return;
-    }
-
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
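
With the tile and wmma kernels now handling sinks themselves, the early return that forced sink cases without fp16 MMA onto the vec kernels (the [TAG_ATTN_SINKS] block removed above) is no longer needed, so dispatch falls through to the regular kernel selection; the local sinks pointer that existed only for that check is removed along with it.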
