33
44#include < mma.h>
55
6- #define FATTN_KQ_STRIDE 256
7- #define HALF_MAX_HALF __float2half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
6+ #define FATTN_KQ_STRIDE 256
7+ #define HALF_MAX_HALF __float2half (65504 .0f /2 ) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
8+ #define SOFTMAX_FTZ_THRESHOLD -20 .0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
89
910template <int D, int parallel_blocks> // D == head size
1011__launch_bounds__ (((D + WARP_SIZE - 1 ) / WARP_SIZE)*WARP_SIZE, 1)
@@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16(
338339#pragma unroll
339340 for (int k0 = 0 ; k0 < FATTN_KQ_STRIDE/2 ; k0 += WARP_SIZE) {
340341 const int k = k0 + threadIdx .x ;
341- KQ_max_new = __hmax2 (KQ_max_new, KQ2[j*(kqs_padded/2 ) + k]);
342+ half2 val = KQ2[j*(kqs_padded/2 ) + k];
343+ val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
344+ KQ_max_new = __hmax2 (KQ_max_new, val);
345+ KQ2[j*(kqs_padded/2 ) + k] = val;
342346 }
343347 KQ_max_new = __half2half2 (warp_reduce_max (__hmax (__low2half (KQ_max_new), __high2half (KQ_max_new))));
344- KQ_max_scale[j0/nwarps] = h2exp (KQ_max[j0/nwarps] - KQ_max_new);
348+ const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
349+ KQ_max_scale[j0/nwarps] = h2exp (diff);
350+ const uint ftz_mask = __hgt2_mask (diff, make_half2 (SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
351+ *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask;
345352 KQ_max[j0/nwarps] = KQ_max_new;
346353
347354 half2 KQ_rowsum_add = make_half2 (0 .0f , 0 .0f );
@@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16(
350357 const int k = k0 + threadIdx .x ;
351358
352359 half2 val = KQ2[j*(kqs_padded/2 ) + k];
353- val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2 (0 .0f , 0 .0f );
354- val = h2exp (val - KQ_max[j0/nwarps]);
360+ const half2 diff = val - KQ_max[j0/nwarps];
361+ val = h2exp (diff);
362+ const uint ftz_mask = __hgt2_mask (diff, make_half2 (SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
363+ *((uint *) &val) &= ftz_mask;
355364 KQ_rowsum_add += val;
356365 KQ2[j*(kqs_padded/2 ) + k] = val;
357366 }
@@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results(
501510 float VKQ_denominator = 0 .0f ;
502511#pragma unroll
503512 for (int l = 0 ; l < parallel_blocks; ++l) {
504- float KQ_max_scale = hexp (__low2half (meta[l]) - kqmax);
513+ const half diff = __low2half (meta[l]) - kqmax;
514+ float KQ_max_scale = hexp (diff);
515+ const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half (SOFTMAX_FTZ_THRESHOLD));
516+ *((uint *) &KQ_max_scale) &= ftz_mask;
505517
506518 VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim .y *D + blockIdx .y *D + tid];
507519 VKQ_denominator += KQ_max_scale * __high2float (meta[l]);
0 commit comments