
Commit 8ae6cc3

CUDA: attention sinks for mma FlashAttention
Port from ggml-org/llama.cpp#15157
Parent: cd0d7f0
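Background for the change: an attention sink is an extra per-head logit that takes part in the softmax normalization but has no value vector, so it contributes exp(sink - max) to the denominator and nothing to the output. A minimal host-side sketch of that definition (illustrative names, not code from this commit):

// Toy illustration of softmax with an attention sink: the sink raises the
// maximum and adds to the denominator, but has no value row of its own.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> softmax_with_sink(const std::vector<float> & logits, float sink) {
    float m = sink;                                   // running maximum includes the sink
    for (float x : logits) m = std::max(m, x);
    float denom = std::exp(sink - m);                 // sink contributes to the denominator only
    for (float x : logits) denom += std::exp(x - m);
    std::vector<float> p(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) p[i] = std::exp(logits[i] - m) / denom;
    return p;
}

int main() {
    const std::vector<float> kq = {0.2f, 1.5f, -0.7f};
    for (float p : softmax_with_sink(kq, /*sink=*/2.0f)) printf("%.4f ", p);
    printf("\n");                                     // weights sum to less than 1; the rest is absorbed by the sink
    return 0;
}

With a large sink the attention weights sum to well under 1; the kernel changes below reproduce exactly this effect inside the online-softmax loop of the mma FlashAttention kernel.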

File tree: 3 files changed, +81 -16 lines

ggml/src/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -3663,6 +3663,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
 #else
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
+                return false;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
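The added gate reads: if the op carries a sinks tensor in src[4] and the device lacks fp16 MMA, only head sizes 64 and 128 are reported as supported, since only those can be served by the non-MMA fallback kernels. A hedged restatement of that predicate (assumed field names, not part of the diff):

// Illustrative-only helper mirroring the early-out added to
// ggml_backend_cuda_supports_op; the struct is a made-up view of the op.
#include <cstdint>

struct fattn_op_info {
    int64_t head_size;   // op->src[0]->ne[0]
    bool    has_sinks;   // op->src[4] != nullptr
};

static bool fattn_supported_without_fp16_mma(const fattn_op_info & op) {
    if (op.has_sinks && op.head_size != 64 && op.head_size != 128) {
        return false;    // sinks without fp16 MMA: only head sizes 64 and 128 pass
    }
    return true;         // everything else falls through to the existing checks
}
// e.g. fattn_supported_without_fp16_mma({ op->src[0]->ne[0], op->src[4] != nullptr })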

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 74 additions & 15 deletions
@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }

+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
+    //     so it's being done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new  = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Write VKQ accumulators to shared memory in column-major format.
     // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // Also for np > 1 the combination is done via these values in shared memory.
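The block above folds the sink into the per-column online-softmax state once, after the K/V loop: the running maximum is raised to the sink if needed, the row sum and the VKQ accumulators are rescaled by exp(old_max - new_max) (flushed to zero below SOFTMAX_FTZ_THRESHOLD), and exp(sink - new_max) is added to the row sum. A scalar sketch of the same update, with illustrative names and an assumed threshold value:

// Scalar sketch of the sink fold-in for one column.
// State: running maximum m, running denominator rowsum, value accumulator acc.
#include <cmath>

struct softmax_state { float m; float rowsum; float acc; };

// ftz_threshold stands in for SOFTMAX_FTZ_THRESHOLD (value assumed here):
// if the rescale factor would underflow anyway, it is flushed to exactly zero.
static void fold_in_sink(softmax_state & s, float sink, float ftz_threshold = -20.0f) {
    const float m_new  = std::fmax(s.m, sink);
    const float m_diff = s.m - m_new;                        // <= 0
    float scale = std::exp(m_diff);
    if (m_diff < ftz_threshold) scale = 0.0f;                // the kernel does this with a bit trick on the float
    s.m      = m_new;
    s.acc    = scale * s.acc;                                // VKQ accumulators are rescaled the same way
    s.rowsum = scale * s.rowsum + std::exp(sink - m_new);    // the sink adds to the denominator, nothing to acc
}

The kernel uses fmaxf/expf and applies the same scale to every VKQ accumulator fragment of the column; the sketch collapses those fragments into the single acc value.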
@@ -889,15 +936,21 @@ static __global__ void flash_attn_mma_ext_f16(
     int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int channel = kbc / (iter_k*iter_j);
-        const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+        const int head0 = zt * ncols2;

-        const float2 * Q_f2  = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2  = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2  * V_h2  = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+        const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
-        float2       * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
+            (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+
+        const half2  * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float  * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
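The new indexing decomposes the flattened work counter kbc into a head-tile index zt (head0 = zt*ncols2) and a tile index jt within the current sequence; head0 then offsets Q, K, V, the mask and the sinks pointer. A small host-side sketch that checks the decomposition round-trips, using made-up sizes:

// Illustrative check of the kbc decomposition used above (host-side only).
// kbc enumerates (sequence, head tile, j tile, k block); given the current
// sequence, zt and jt are recovered exactly as in the kernel.
#include <cassert>

int main() {
    const int iter_k = 8, iter_j = 4, ne02 = 32, ncols2 = 2;    // made-up sizes
    const int tiles_per_seq = iter_k * iter_j * (ne02 / ncols2);

    for (int kbc = 0; kbc < 3 * tiles_per_seq; ++kbc) {
        const int sequence = kbc / tiles_per_seq;
        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k;
        const int k  =  kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt - iter_k*jt;
        const int head0 = zt * ncols2;                           // first head handled by this tile

        // Round-trip: re-flatten and compare.
        assert(kbc == ((sequence*(ne02/ncols2) + zt)*iter_j + jt)*iter_k + k);
        assert(0 <= head0 && head0 < ne02);
    }
    return 0;
}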
@@ -906,12 +959,12 @@ static __global__ void flash_attn_mma_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }

@@ -927,23 +980,29 @@ static __global__ void flash_attn_mma_ext_f16(
     }

     const int channel = kbc / (iter_k*iter_j);
-    const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

-    const float2 * Q_f2  = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+    const int head0 = zt * ncols2;
+
+    const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
-    float2       * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
+        (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+
+    const half2  * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float  * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;

     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 1 deletion
@@ -461,13 +461,14 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K     = dst->src[1];
     const ggml_tensor * V     = dst->src[2];
     const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[3];

     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD) {
+    if (cc >= CC_OFFSET_AMD || (sinks && !fp16_mma_available(cc))) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
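Together with the supports_op change above, the dispatch now routes any sinked attention op on a GPU without fp16 MMA to the vec kernels, which is presumably why only head sizes 64 and 128 are accepted in that case. A restatement of the fallback condition as a standalone helper, with assumed names (not the real dispatch code):

// Illustrative helper: the vec path is taken either on AMD, where the tile
// kernels perform poorly, or whenever attention sinks are requested on a
// device without fp16 MMA, which the mma sink path requires.
static bool use_vec_flash_attn(int cc, int cc_offset_amd, bool has_sinks, bool fp16_mma_ok) {
    return cc >= cc_offset_amd || (has_sinks && !fp16_mma_ok);
}
// e.g. use_vec_flash_attn(cc, CC_OFFSET_AMD, dst->src[4] != nullptr, fp16_mma_available(cc))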
