Skip to content

Commit e95d0bc

Browse files
CUDA: fix FA VKQ accumulator overflow (#17746)
1 parent 668ed76 commit e95d0bc

File tree

5 files changed

+11
-5
lines changed

5 files changed

+11
-5
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
1111
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
1212

13+
// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
14+
// by the VKQ accumulators is effectively being shifted up by a factor of 8.
15+
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
16+
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
17+
#define FATTN_KQ_MAX_OFFSET 0.6931f
18+
1319
typedef void (* fattn_kernel_t)(
1420
const char * __restrict__ Q,
1521
const char * __restrict__ K,

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
532532
#pragma unroll
533533
for (int l = 0; l < T_C_KQ::ne; ++l) {
534534
if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
535-
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
535+
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
536536
}
537537
}
538538
}
@@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
585585
for (int l = 0; l < T_C_KQ::ne; ++l) {
586586
if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
587587
// Turing + Volta:
588-
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
588+
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
589589
}
590590
}
591591
}

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
572572
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
573573
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
574574

575-
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
575+
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
576576
}
577577
}
578578

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ static __global__ void flash_attn_ext_vec(
270270
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
271271
}
272272

273-
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
273+
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
274274

275275
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
276276
KQ_reg[j] = sum;

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ static __global__ void flash_attn_ext_f16(
220220

221221
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
222222
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
223-
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
223+
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
224224
}
225225
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
226226

0 commit comments

Comments
 (0)