
Commit b46e828

Revert "CUDA: add attention sinks for tile and wmma (ggml-org#15178)"
This reverts commit 34c9d76.
1 parent 0b83819 commit b46e828

4 files changed: +24 -115 lines changed

ggml/src/ggml-cuda/fattn-tile-f16.cu
Lines changed: 4 additions & 30 deletions

@@ -49,11 +49,10 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const float * sinksf = (const float *) (sinks);
+    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

@@ -243,31 +242,6 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }

-    //Attention sink: adjust running max and sum once per head
-    if (sinksf && blockIdx.y == 0) {
-        const half sink = __float2half(sinksf[head]);
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-
-            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
-            kqmax[j0/nwarps] = kqmax_new_j;
-
-            const half val = hexp(sink - kqmax[j0/nwarps]);
-            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
-            if (threadIdx.x == 0) {
-                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
-            }
-        }
-    }
-
     float2 * dst2 = (float2 *) dst;

 #pragma unroll
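For context on what is being reverted: the removed block folds the per-head sink value into the kernel's online softmax as if it were one extra logit with no V row. The running maximum may rise, the existing denominator and V accumulator are rescaled, and exp(sink - new_max) is added to the denominator only. A minimal host-side sketch of that update in scalar float (illustrative names, not the kernel's half2 registers):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Fold an attention sink into an online-softmax running state.
    // m: running max of logits, s: running sum of exp(logit - m),
    // vkq: accumulated exp(logit - m) * V rows for one query.
    static void apply_attention_sink(float sink, float & m, float & s, std::vector<float> & vkq) {
        const float m_new = std::max(m, sink);   // the sink behaves like one more logit
        const float scale = std::exp(m - m_new); // rescale everything accumulated so far
        s = s*scale + std::exp(sink - m_new);    // the sink contributes to the denominator only
        for (float & v : vkq) {
            v *= scale;                          // no V row for the sink, so vkq is only rescaled
        }
        m = m_new;
    }

In the kernel the same arithmetic runs in half2 registers, and the exp(sink - new_max) term is added on one lane only (the threadIdx.x == 0 guard) so it is not counted once per lane when the per-lane sums are later reduced across the warp.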

ggml/src/ggml-cuda/fattn-tile-f32.cu
Lines changed: 4 additions & 32 deletions

@@ -60,11 +60,10 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const float * sinksf = (const float *) (sinks);
+    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

@@ -253,33 +252,6 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }

-
-    //Attention sink: adjust running max and sum once per head
-    if (sinksf && blockIdx.y == 0) {
-        const float sink = sinksf[head];
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-
-            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
-            kqmax[j0/nwarps] = kqmax_new_j;
-
-            const float val = expf(sink - kqmax[j0/nwarps]);
-            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
-            if (threadIdx.x == 0) {
-                kqsum[j0/nwarps] += val;
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
-                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
-            }
-        }
-    }
-
     float2 * dst2 = (float2 *) dst;

 #pragma unroll
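The f32 variant removed above performs the same adjustment in single precision; the only warp-level cooperation it needs is warp_reduce_max, so that every lane agrees on the new running maximum before rescaling. A hedged sketch of that kind of butterfly (xor-shuffle) reduction, written as a standalone device helper rather than the helper ggml-cuda actually defines:

    #include <cuda_runtime.h>

    // Max-reduction across the 32 lanes of a warp via xor shuffles;
    // after the loop every lane holds the warp-wide maximum.
    __device__ float warp_max_f32(float x) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
        }
        return x;
    }

    // Tiny usage example: one warp reduces its inputs to a single maximum.
    __global__ void warp_max_demo(const float * in, float * out) {
        const float m = warp_max_f32(in[threadIdx.x]);
        if (threadIdx.x == 0) {
            *out = m;
        }
    }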

ggml/src/ggml-cuda/fattn-wmma-f16.cu
Lines changed: 5 additions & 53 deletions

@@ -82,12 +82,11 @@ static __global__ void flash_attn_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const half2 * mask2 = (const half2 *) maskh;
-    const float * sinksf = (const float *) sinks;
+    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half2 * mask2 = (const half2 *) maskh;

     const int stride_Q = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);

@@ -382,53 +381,6 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }

-    // Apply attention sinks
-    if (sinksf && blockIdx.y == 0) {
-        const float sinkf = sinksf[head];
-        const half sinkh = __float2half(sinkf);
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            if (std::is_same<KQ_acc_t, float>::value) {
-                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
-
-                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
-                KQ_max_f[j0/nwarps] = kqmax_new;
-
-                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
-
-                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
-                    if (i0 + warp_size > D/2 && i >= D/2) break;
-                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
-                }
-            } else {
-                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
-                half kqmax_new = fmaxf(kqmax_old, sinkh);
-                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
-
-                const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
-                const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
-
-                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
-                const half val = hexp(sinkh - kqmax_new);
-                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
-
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
-                    if (i0 + warp_size > D/2 && i >= D/2) break;
-                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
-                }
-            }
-        }
-
-        __syncthreads();
-    }
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
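In the wmma kernel the same correction existed twice, once for float accumulators and once for half2 accumulators. The half2 path keeps the row sum as a half2 whose two halves are combined later, which is why the removed code adds the sink's exp term to the .x component only, so it is counted once. A sketch of that half-precision update using standard cuda_fp16 intrinsics (device code for sm_53+, illustrative names, not a drop-in for the kernel's registers):

    #include <cuda_fp16.h>

    // Half2 flavour of the sink fold: rescale the accumulated state and add the
    // sink's contribution to one half of the two-lane row sum so it is counted once.
    __device__ void apply_sink_half2(half sink, half2 & kq_max, half2 & kq_rowsum,
                                     half2 * vkq, int n) {
        const half m_old = __low2half(kq_max);
        const half m_new = __float2half(fmaxf(__half2float(m_old), __half2float(sink)));
        kq_max = __half2half2(m_new);

        const half2 scale = __half2half2(hexp(__hsub(m_old, m_new)));
        kq_rowsum = __hmul2(kq_rowsum, scale);
        kq_rowsum = __halves2half2(__hadd(__low2half(kq_rowsum), hexp(__hsub(sink, m_new))),
                                   __high2half(kq_rowsum));

        for (int i = 0; i < n; ++i) {
            vkq[i] = __hmul2(vkq[i], scale); // rescale the accumulated V*softmax(QK) values
        }
    }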

ggml/src/ggml-cuda/fattn.cu
Lines changed: 11 additions & 0 deletions

@@ -274,12 +274,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

+    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
+    if (sinks && !fp16_mma_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
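The net effect of the fattn.cu hunk: after this revert, only the vec kernels handle attention sinks on devices without fp16 MMA support, so any such request is routed to them before the normal kernel selection runs. A condensed restatement of that decision as a pure function (the enum and function name are illustrative, not part of the file):

    // Which flash-attention path the re-added guard selects when sinks are present
    // on a device without fp16 MMA; all other cases fall through unchanged.
    enum class fattn_path { vec_f16, vec_f32, fall_through };

    static fattn_path pick_sink_fallback(bool has_sinks, bool fp16_mma, bool fast_fp16, bool prec_default) {
        if (has_sinks && !fp16_mma) {
            return (prec_default && fast_fp16) ? fattn_path::vec_f16 : fattn_path::vec_f32;
        }
        return fattn_path::fall_through; // remaining dispatch logic is untouched by this commit
    }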
