Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
Expand Down Expand Up @@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}

// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
// so it's being done unconditionally for every thread.
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];

const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;

*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;

const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}

if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}

// Combine VKQ accumulator values if np > 1.
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
Expand Down Expand Up @@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(

while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const int head0 = zt * ncols2;

const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);

const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
Expand All @@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}

Expand All @@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
}

const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

const int head0 = zt * ncols2;

const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);

const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
Expand All @@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

// TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
if (sinks) {
if (sinks && !fp16_mma_available(cc)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JohannesGaessler I think this breaks Volta, since fp16_mma_available is true but the wmma kernel doesn't yet support attention sinks

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, this should be turing_mma_available.

if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
Expand Down
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
// TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the !fp16_mma_available check needed here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for model support because the vector kernels I think cover all currently available models with attention sinks. But this enables running tests with attention sinks and head sizes != 64/128 so I thought it would be better to adjust.

&& op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
return false;
}
if (op->src[0]->ne[0] == 192) {
Expand Down
Loading