@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
31843184 threadgroup_barrier (mem_flags::mem_threadgroup);
31853185
31863186 {
3187- half S[Q] = { [0 ... Q-1 ] = 0 .0f };
3188- half M[Q] = { [0 ... Q-1 ] = -__FLT16_MAX__/2 };
3187+ float S[Q] = { [0 ... Q-1 ] = 0 .0f };
3188+ float M[Q] = { [0 ... Q-1 ] = -__FLT16_MAX__/2 };
31893189
31903190 // thread indices inside the simdgroup
31913191 // TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
32023202
32033203 const bool has_mask = mask != q;
32043204
3205- half slope = 1 .0f ;
3205+ float slope = 1 .0f ;
32063206
32073207 // ALiBi
32083208 if (args.max_bias > 0 .0f ) {
32093209 const short h = iq2;
32103210
3211- const half base = h < args.n_head_log2 ? args.m0 : args.m1 ;
3211+ const float base = h < args.n_head_log2 ? args.m0 : args.m1 ;
32123212 const short exph = h < args.n_head_log2 ? h + 1 : 2 *(h - args.n_head_log2 ) + 1 ;
32133213
32143214 slope = pow (base, exph);
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
32243224
32253225 if (has_mask) {
32263226 // used to detect blocks full of -INF
3227- half smax = -INFINITY;
3227+ float smax = -INFINITY;
32283228
32293229 // load the mask in shared memory
32303230 #pragma unroll(Q)
32313231 for (short j = 0 ; j < Q; ++j) {
32323232 device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 );
32333233
3234- const half m = pm[ic + tiisg];
3234+ const float m = pm[ic + tiisg];
32353235
32363236 ss[j*TS + C + tiisg] = m;
32373237 smax = max (smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
33273327 // online softmax
33283328 {
33293329 for (ushort j = 0 ; j < Q; ++j) {
3330- const half m = M[j];
3330+ const float m = M[j];
33313331
33323332 // scale and apply the logitcap / mask
3333- half s = ss[j*TS + tiisg]*args.scale ;
3333+ float s = ss[j*TS + tiisg]*args.scale ;
33343334
33353335 if (args.logit_softcap != 0 .0f ) {
33363336 s = args.logit_softcap *precise::tanh (s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
33413341
33423342 M[j] = simd_max (max (M[j], s));
33433343
3344- const half ms = exp (m - M[j]);
3345- const half vs = exp (s - M[j]);
3344+ const float ms = exp (m - M[j]);
3345+ const float vs = exp (s - M[j]);
33463346
33473347 S[j] = S[j]*ms + simd_sum (vs);
33483348
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
34443444
34453445 // reduce the warps sequentially
34463446 for (ushort sg = 1 ; sg < nsg; ++sg) {
3447- half S = { 0 .0f };
3448- half M = { -__FLT16_MAX__/2 };
3447+ float S = { 0 .0f };
3448+ float M = { -__FLT16_MAX__/2 };
34493449
34503450 threadgroup_barrier (mem_flags::mem_threadgroup);
34513451
@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
34613461 // the first simdgroup accumulates the results from the other simdgroups
34623462 if (sgitg == 0 ) {
34633463 for (short j = 0 ; j < Q; ++j) {
3464- const half S0 = ss[j*TS + 0 ];
3465- const half S1 = ss[j*TS + sg*SH + 0 ];
3464+ const float S0 = ss[j*TS + 0 ];
3465+ const float S1 = ss[j*TS + sg*SH + 0 ];
34663466
3467- const half M0 = ss[j*TS + 1 ];
3468- const half M1 = ss[j*TS + sg*SH + 1 ];
3467+ const float M0 = ss[j*TS + 1 ];
3468+ const float M1 = ss[j*TS + sg*SH + 1 ];
34693469
34703470 M = max (M0, M1);
34713471
3472- const half ms0 = exp (M0 - M);
3473- const half ms1 = exp (M1 - M);
3472+ const float ms0 = exp (M0 - M);
3473+ const float ms1 = exp (M1 - M);
34743474
34753475 S = S0*ms0 + S1*ms1;
34763476
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
36463646 constexpr short DV4 = DV/4 ;
36473647 constexpr short NW = N_SIMDWIDTH;
36483648 constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649- constexpr short SH = 2 *C; // shared memory per simdgroup
3649+ constexpr short SH = 4 *C; // shared memory per simdgroup
36503650
36513651 const short T = DK + nsg*SH; // shared memory size per query in (half)
36523652
3653- // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3655- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
3658- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3653+ // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3655+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 * C + Q*DK); // scratch buffer for mask
3658+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
36593659
36603660 // store the result for all queries in local memory (the O matrix from the paper)
36613661 o4_t lo[DV4/NL];
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
36843684 threadgroup_barrier (mem_flags::mem_threadgroup);
36853685
36863686 {
3687- half S = 0 .0f ;
3688- half M = -__FLT16_MAX__/2 ;
3687+ float S = 0 .0f ;
3688+ float M = -__FLT16_MAX__/2 ;
36893689
36903690 // thread indices inside the simdgroup
36913691 const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
37033703 // pointer to the mask
37043704 device const half * pm = (device const half *) (mask + iq1*args.nb31 );
37053705
3706- half slope = 1 .0f ;
3706+ float slope = 1 .0f ;
37073707
37083708 // ALiBi
37093709 if (args.max_bias > 0 .0f ) {
37103710 const short h = iq2;
37113711
3712- const half base = h < args.n_head_log2 ? args.m0 : args.m1 ;
3712+ const float base = h < args.n_head_log2 ? args.m0 : args.m1 ;
37133713 const short exph = h < args.n_head_log2 ? h + 1 : 2 *(h - args.n_head_log2 ) + 1 ;
37143714
37153715 slope = pow (base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
37993799
38003800 // online softmax
38013801 {
3802- const half m = M;
3803- const half s = ss[tiisg];
3802+ const float m = M;
3803+ const float s = ss[tiisg];
38043804
38053805 M = simd_max (max (M, s));
38063806
3807- const half ms = exp (m - M);
3808- const half vs = exp (s - M);
3807+ const float ms = exp (m - M);
3808+ const float vs = exp (s - M);
38093809
38103810 S = S*ms + simd_sum (vs);
38113811
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
38363836 v4_t mv;
38373837 deq_v_t4 (pv4 + i/nl_v, i%nl_v, mv);
38383838
3839- lo[ii/NL] += mv*ms ;
3839+ lo[ii/NL] += o4_t ( float4 (mv)* float4 (ms)) ;
38403840 }
38413841 }
38423842 }
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
39073907 // parallel reduce
39083908 for (short r = nsg/2 ; r > 0 ; r >>= 1 ) {
39093909 if (sgitg < r) {
3910- const half S0 = ss[ 0 ];
3911- const half S1 = ss[r*SH + 0 ];
3910+ const float S0 = ss[ 0 ];
3911+ const float S1 = ss[r*(SH/ 2 ) + 0 ];
39123912
3913- const half M0 = ss[ 1 ];
3914- const half M1 = ss[r*SH + 1 ];
3913+ const float M0 = ss[ 1 ];
3914+ const float M1 = ss[r*(SH/ 2 ) + 1 ];
39153915
3916- const half M = max (M0, M1);
3916+ const float M = max (M0, M1);
39173917
3918- const half ms0 = exp (M0 - M);
3919- const half ms1 = exp (M1 - M);
3918+ const float ms0 = exp (M0 - M);
3919+ const float ms1 = exp (M1 - M);
39203920
3921- const half S = S0*ms0 + S1*ms1;
3921+ const float S = S0*ms0 + S1*ms1;
39223922
39233923 if (tiisg == 0 ) {
39243924 ss[0 ] = S;
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
39503950// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
39513951//
39523952#define FA_TYPES \
3953- half4, \
3954- half4, \
3955- half4, \
3956- float , \
3957- half, half4 , \
3953+ half4, \
3954+ half4, \
3955+ half4, \
3956+ float , \
3957+ float , float4 , \
39583958 half4
39593959
39603960typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
0 commit comments