File tree Expand file tree Collapse file tree 5 files changed +11
-5
lines changed Expand file tree Collapse file tree 5 files changed +11
-5
lines changed Original file line number Diff line number Diff line change 1010#define HALF_MAX_HALF __float2half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
1111#define SOFTMAX_FTZ_THRESHOLD -20 .0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
1212
13+ // log(2) = 0.6931; by adding this to the KQ maximum used for the softmax, the numerical range representable
14+ // by the VKQ accumulators is effectively being shifted up by a factor of 2 (since exp(0.6931) = 2).
15+ // This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
16+ // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
17+ #define FATTN_KQ_MAX_OFFSET 0 .6931f
18+
1319typedef void (* fattn_kernel_t )(
1420 const char * __restrict__ Q,
1521 const char * __restrict__ K,
Original file line number Diff line number Diff line change @@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
532532#pragma unroll
533533 for (int l = 0 ; l < T_C_KQ::ne; ++l) {
534534 if (!oob_check || k0 + T_C_KQ::get_i (l) < k_VKQ_sup) {
535- KQ_max_new[l % 2 ] = fmaxf (KQ_max_new[l % 2 ], KQ_C[k0/(np*T_C_KQ::I)].x [l]);
535+ KQ_max_new[l % 2 ] = fmaxf (KQ_max_new[l % 2 ], KQ_C[k0/(np*T_C_KQ::I)].x [l] + FATTN_KQ_MAX_OFFSET );
536536 }
537537 }
538538 }
@@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
585585 for (int l = 0 ; l < T_C_KQ::ne; ++l) {
586586 if (!oob_check || k0 + T_C_KQ::get_j (l) < k_VKQ_sup) {
587587 // Turing + Volta:
588- KQ_max_new[(l/2 ) % 2 ] = fmaxf (KQ_max_new[(l/2 ) % 2 ], KQ_C[(k0/(np*T_C_KQ::J))].x [l]);
588+ KQ_max_new[(l/2 ) % 2 ] = fmaxf (KQ_max_new[(l/2 ) % 2 ], KQ_C[(k0/(np*T_C_KQ::J))].x [l] + FATTN_KQ_MAX_OFFSET );
589589 }
590590 }
591591 }
Original file line number Diff line number Diff line change @@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
572572 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
573573 slope*__half2float (mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0 .0f ;
574574
575- KQ_max_new[jc0] = fmaxf (KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
575+ KQ_max_new[jc0] = fmaxf (KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET );
576576 }
577577 }
578578
Original file line number Diff line number Diff line change @@ -270,7 +270,7 @@ static __global__ void flash_attn_ext_vec(
270270 sum += slope*__half2float (maskh[j*ne11 + i_KQ]);
271271 }
272272
273- KQ_max_new[j] = fmaxf (KQ_max_new[j], sum);
273+ KQ_max_new[j] = fmaxf (KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET );
274274
275275 if ((nthreads_KQ == WARP_SIZE ? threadIdx .x : threadIdx .x % nthreads_KQ) == uint32_t (i_KQ_0)) {
276276 KQ_reg[j] = sum;
Original file line number Diff line number Diff line change @@ -220,7 +220,7 @@ static __global__ void flash_attn_ext_f16(
220220
221221 KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int (ne01.z ) ?
222222 __half2float (slopeh*maskh[j*(nb31/sizeof (half)) + k_VKQ_0 + k]) : 0 .0f ;
223- KQ_max_new = max (KQ_max_new, KQ_f_tmp[k0/warp_size]);
223+ KQ_max_new = max (KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET );
224224 }
225225 KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
226226
You can’t perform that action at this time.
0 commit comments