@@ -2844,7 +2844,7 @@ kernel void kernel_flash_attn_ext(
28442844 const short D8 = D/8 ;
28452845 const short D16 = D/16 ;
28462846 const short NW = N_SIMDWIDTH;
2847- const short SH = (2 *C + Q); // shared memory per simdgroup in (half )
2847+ const short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float )
28482848
28492849 const short TS = nsg*SH; // shared memory size per query in (s_t == float)
28502850 const short T = D + 2 *TS; // shared memory size per query in (half)
@@ -3353,16 +3353,17 @@ kernel void kernel_flash_attn_ext_vec(
33533353 const short D16 = D/16 ;
33543354 const short NW = N_SIMDWIDTH;
33553355 const short NW4 = NW/4 ;
3356- const short SH = C; // shared memory per simdgroup in (half)
3356+ const short SH = 2 * C; // shared memory per simdgroup
33573357
33583358 const short T = D + 2 *nsg*SH; // shared memory size per query in (half)
33593359
3360- // threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0 *D); // same as above but in half4
3362- threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0 *D); // same as above but in half4x4
3363- threadgroup s_t * ss = (threadgroup s_t *) (shared + 2 *sgitg*SH + Q*D); // scratch buffer for attention
3364- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2 *sgitg*SH + Q*D); // same as above but in half4
3365- threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3360+ // threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
3361+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0 *D); // same as above but in q4_t
3362+ threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0 *D); // same as above but in q4x4_t
3363+ threadgroup s_t * ss = (threadgroup s_t *) (shared + 2 *sgitg*SH + Q*D); // scratch buffer for attention
3364+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2 *sgitg*SH + Q*D); // same as above but in s4_t
3365+ threadgroup half * sm = (threadgroup half *) (shared + 2 *sgitg*SH + SH + Q*D); // scratch buffer for mask
3366+ threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
33663367
33673368 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
33683369 o4x4_t lo[D16/NW4];
@@ -3412,8 +3413,10 @@ kernel void kernel_flash_attn_ext_vec(
34123413 mq[ii/NW4] = sq4x4[ii + tx];
34133414 }
34143415
3416+ const bool has_mask = mask != q;
3417+
34153418 // pointer to the mask
3416- device const half * mp = (device const half *) (mask + iq1*nb31);
3419+ device const half * pm = (device const half *) (mask + iq1*nb31);
34173420
34183421 half slope = 1 .0f ;
34193422
@@ -3435,6 +3438,10 @@ kernel void kernel_flash_attn_ext_vec(
34353438 break ;
34363439 }
34373440
3441+ if (has_mask) {
3442+ sm[tiisg] = pm[ic + tiisg];
3443+ }
3444+
34383445 // Q*K^T
34393446 {
34403447 // each simdgroup processes 1 query and 4 keys
@@ -3476,7 +3483,7 @@ kernel void kernel_flash_attn_ext_vec(
34763483 mqk = logit_softcap*precise::tanh (mqk);
34773484 }
34783485
3479- mqk += ( s_t ) ((mask != q) ? (( float ) mp[ic + 4 *cc + ty]) *slope : ( float ) 0 . 0f ) ;
3486+ mqk += sm[ 4 *cc + ty]*slope;
34803487
34813488 ss[4 *cc + ty] = mqk;
34823489 }
0 commit comments