@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
36463646 constexpr short DV4 = DV/4 ;
36473647 constexpr short NW = N_SIMDWIDTH;
36483648 constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649- constexpr short SH = 2 *C; // shared memory per simdgroup
3649+ constexpr short SH = 4 *C; // shared memory per simdgroup
36503650
36513651 const short T = DK + nsg*SH; // shared memory size per query in (half)
36523652
3653- // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3655- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
3658- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3653+ // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3654+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3655+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3656+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3657+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 * C + Q*DK); // scratch buffer for mask
3658+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
36593659
36603660 // store the result for all queries in local memory (the O matrix from the paper)
36613661 o4_t lo[DV4/NL];
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
38363836 v4_t mv;
38373837 deq_v_t4 (pv4 + i/nl_v, i%nl_v, mv);
38383838
3839- lo[ii/NL] += dot (( float4) mv, (float4) ms );
3839+ lo[ii/NL] += o4_t ( float4 (mv)* float4 (ms) );
38403840 }
38413841 }
38423842 }
@@ -3907,11 +3907,11 @@ kernel void kernel_flash_attn_ext_vec(
39073907 // parallel reduce
39083908 for (short r = nsg/2 ; r > 0 ; r >>= 1 ) {
39093909 if (sgitg < r) {
3910- const float S0 = ss[ 0 ];
3911- const float S1 = ss[r*SH + 0 ];
3910+ const float S0 = ss[ 0 ];
3911+ const float S1 = ss[r*(SH/ 2 ) + 0 ];
39123912
3913- const float M0 = ss[ 1 ];
3914- const float M1 = ss[r*SH + 1 ];
3913+ const float M0 = ss[ 1 ];
3914+ const float M1 = ss[r*(SH/ 2 ) + 1 ];
39153915
39163916 const float M = max (M0, M1);
39173917
0 commit comments