@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
31843184    threadgroup_barrier (mem_flags::mem_threadgroup);
31853185
31863186    {
3187-         half  S[Q] = { [0  ... Q-1 ] = 0 .0f  };
3188-         half  M[Q] = { [0  ... Q-1 ] = -__FLT16_MAX__/2  };
3187+         float  S[Q] = { [0  ... Q-1 ] = 0 .0f  };
3188+         float  M[Q] = { [0  ... Q-1 ] = -__FLT16_MAX__/2  };
31893189
31903190        //  thread indices inside the simdgroup
31913191        //  TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
32023202
32033203        const  bool  has_mask = mask != q;
32043204
3205-         half  slope = 1 .0f ;
3205+         float  slope = 1 .0f ;
32063206
32073207        //  ALiBi
32083208        if  (args.max_bias  > 0 .0f ) {
32093209            const  short  h = iq2;
32103210
3211-             const  half   base = h < args.n_head_log2  ? args.m0  : args.m1 ;
3211+             const  float  base = h < args.n_head_log2  ? args.m0  : args.m1 ;
32123212            const  short  exph = h < args.n_head_log2  ? h + 1  : 2 *(h - args.n_head_log2 ) + 1 ;
32133213
32143214            slope = pow (base, exph);
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
32243224
32253225            if  (has_mask) {
32263226                //  used to detect blocks full of -INF
3227-                 half  smax = -INFINITY;
3227+                 float  smax = -INFINITY;
32283228
32293229                //  load the mask in shared memory
32303230                #pragma  unroll(Q)
32313231                for  (short  j = 0 ; j < Q; ++j) {
32323232                    device const  half * pm = (device const  half *) ((device const  char  *) mask + (iq1 + j)*args.nb31 );
32333233
3234-                     const  half  m = pm[ic + tiisg];
3234+                     const  float  m = pm[ic + tiisg];
32353235
32363236                    ss[j*TS + C + tiisg] = m;
32373237                    smax = max (smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
33273327            //  online softmax
33283328            {
33293329                for  (ushort j = 0 ; j < Q; ++j) {
3330-                     const  half  m = M[j];
3330+                     const  float  m = M[j];
33313331
33323332                    //  scale and apply the logitcap / mask
3333-                     half  s = ss[j*TS + tiisg]*args.scale ;
3333+                     float  s = ss[j*TS + tiisg]*args.scale ;
33343334
33353335                    if  (args.logit_softcap  != 0 .0f ) {
33363336                        s = args.logit_softcap *precise::tanh (s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
33413341
33423342                    M[j] = simd_max (max (M[j], s));
33433343
3344-                     const  half  ms = exp (m - M[j]);
3345-                     const  half  vs = exp (s - M[j]);
3344+                     const  float  ms = exp (m - M[j]);
3345+                     const  float  vs = exp (s - M[j]);
33463346
33473347                    S[j] = S[j]*ms + simd_sum (vs);
33483348
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
34443444
34453445    //  reduce the warps sequentially
34463446    for  (ushort sg = 1 ; sg < nsg; ++sg) {
3447-         half  S = { 0 .0f  };
3448-         half  M = { -__FLT16_MAX__/2  };
3447+         float  S = { 0 .0f  };
3448+         float  M = { -__FLT16_MAX__/2  };
34493449
34503450        threadgroup_barrier (mem_flags::mem_threadgroup);
34513451
@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
34613461        //  the first simdgroup accumulates the results from the other simdgroups
34623462        if  (sgitg == 0 ) {
34633463            for  (short  j = 0 ; j < Q; ++j) {
3464-                 const  half  S0 = ss[j*TS +         0 ];
3465-                 const  half  S1 = ss[j*TS + sg*SH + 0 ];
3464+                 const  float  S0 = ss[j*TS +         0 ];
3465+                 const  float  S1 = ss[j*TS + sg*SH + 0 ];
34663466
3467-                 const  half  M0 = ss[j*TS +         1 ];
3468-                 const  half  M1 = ss[j*TS + sg*SH + 1 ];
3467+                 const  float  M0 = ss[j*TS +         1 ];
3468+                 const  float  M1 = ss[j*TS + sg*SH + 1 ];
34693469
34703470                M = max (M0, M1);
34713471
3472-                 const  half  ms0 = exp (M0 - M);
3473-                 const  half  ms1 = exp (M1 - M);
3472+                 const  float  ms0 = exp (M0 - M);
3473+                 const  float  ms1 = exp (M1 - M);
34743474
34753475                S = S0*ms0 + S1*ms1;
34763476
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
36843684    threadgroup_barrier (mem_flags::mem_threadgroup);
36853685
36863686    {
3687-         half  S = 0 .0f ;
3688-         half  M = -__FLT16_MAX__/2 ;
3687+         float  S = 0 .0f ;
3688+         float  M = -__FLT16_MAX__/2 ;
36893689
36903690        //  thread indices inside the simdgroup
36913691        const  short  tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
37033703        //  pointer to the mask
37043704        device const  half * pm = (device const  half *) (mask + iq1*args.nb31 );
37053705
3706-         half  slope = 1 .0f ;
3706+         float  slope = 1 .0f ;
37073707
37083708        //  ALiBi
37093709        if  (args.max_bias  > 0 .0f ) {
37103710            const  short  h = iq2;
37113711
3712-             const  half   base = h < args.n_head_log2  ? args.m0  : args.m1 ;
3712+             const  float  base = h < args.n_head_log2  ? args.m0  : args.m1 ;
37133713            const  short  exph = h < args.n_head_log2  ? h + 1  : 2 *(h - args.n_head_log2 ) + 1 ;
37143714
37153715            slope = pow (base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
37993799
38003800            //  online softmax
38013801            {
3802-                 const  half  m = M;
3803-                 const  half  s = ss[tiisg];
3802+                 const  float  m = M;
3803+                 const  float  s = ss[tiisg];
38043804
38053805                M = simd_max (max (M, s));
38063806
3807-                 const  half  ms = exp (m - M);
3808-                 const  half  vs = exp (s - M);
3807+                 const  float  ms = exp (m - M);
3808+                 const  float  vs = exp (s - M);
38093809
38103810                S = S*ms + simd_sum (vs);
38113811
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
38363836                        v4_t  mv;
38373837                        deq_v_t4 (pv4 + i/nl_v, i%nl_v, mv);
38383838
3839-                         lo[ii/NL] += mv*ms ;
3839+                         lo[ii/NL] += dot ((float4) mv, (float4) ms) ;
38403840                    }
38413841                }
38423842            }
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
39073907    //  parallel reduce
39083908    for  (short  r = nsg/2 ; r > 0 ; r >>= 1 ) {
39093909        if  (sgitg < r) {
3910-             const  half  S0 = ss[       0 ];
3911-             const  half  S1 = ss[r*SH + 0 ];
3910+             const  float  S0 = ss[       0 ];
3911+             const  float  S1 = ss[r*SH + 0 ];
39123912
3913-             const  half  M0 = ss[       1 ];
3914-             const  half  M1 = ss[r*SH + 1 ];
3913+             const  float  M0 = ss[       1 ];
3914+             const  float  M1 = ss[r*SH + 1 ];
39153915
3916-             const  half  M = max (M0, M1);
3916+             const  float  M = max (M0, M1);
39173917
3918-             const  half  ms0 = exp (M0 - M);
3919-             const  half  ms1 = exp (M1 - M);
3918+             const  float  ms0 = exp (M0 - M);
3919+             const  float  ms1 = exp (M1 - M);
39203920
3921-             const  half  S = S0*ms0 + S1*ms1;
3921+             const  float  S = S0*ms0 + S1*ms1;
39223922
39233923            if  (tiisg == 0 ) {
39243924                ss[0 ] = S;
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
39503950//        in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
39513951// 
39523952#define  FA_TYPES  \
3953-            half4, \
3954-            half4, \
3955-            half4, \
3956-     float ,        \
3957-     half,  half4 , \
3953+            half4,   \
3954+            half4,   \
3955+            half4,   \
3956+     float ,          \
3957+     float , float4 , \
39583958           half4
39593959
39603960typedef  decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
0 commit comments