@@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
27762776 const short iv3 = iq3 / rv3;
27772777
27782778 // load the queries from shared memory into local memory
2779- float4 mq[D4];
2779+ float4 mq[D4/NW ];
27802780
27812781 for (short ii = 0 ; ii < D4; ii += NW) {
27822782 short i = ii + tiisg;
2783- mq[i ] = (float4) sq4[i];
2783+ mq[ii/NW ] = (float4) sq4[i];
27842784 }
27852785
27862786 // pointer to the mask
@@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28122812 mk[2 ] = (float4) pk4[i + 2 *(nb11/8 )];
28132813 mk[3 ] = (float4) pk4[i + 3 *(nb11/8 )];
28142814
2815- mqk += (float4) (mq[i ] * mk);
2815+ mqk += (float4) (mq[ii/NW ] * mk);
28162816 }
28172817
28182818 // reduce the results from the threads in the simdgroup
@@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28572857 // O = diag(ms)*O
28582858#pragma unroll
28592859 for (short ii = 0 ; ii < D4; ii += NW) {
2860- const short i = ii + tiisg;
2861- lo[i/NW] *= ms;
2860+ lo[ii/NW] *= ms;
28622861 }
28632862 }
28642863
@@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
28722871 for (short ii = 0 ; ii < D4; ii += NW) {
28732872 const short i = ii + tiisg;
28742873
2875- lo[i /NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2876- lo[i /NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2877- lo[i /NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2878- lo[i /NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
2874+ lo[ii /NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2875+ lo[ii /NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2876+ lo[ii /NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2877+ lo[ii /NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
28792878 }
28802879 }
28812880 }
0 commit comments