@@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
27762776        const  short  iv3 = iq3 / rv3;
27772777
27782778        //  load the queries from shared memory into local memory
2779-         float4 mq[D4];
2779+         float4 mq[D4/NW ];
27802780
27812781        for  (short  ii = 0 ; ii < D4; ii += NW) {
27822782            short  i = ii + tiisg;
2783-             mq[i ] = (float4) sq4[i];
2783+             mq[ii/NW ] = (float4) sq4[i];
27842784        }
27852785
27862786        //  pointer to the mask
@@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28122812                        mk[2 ] = (float4) pk4[i + 2 *(nb11/8 )];
28132813                        mk[3 ] = (float4) pk4[i + 3 *(nb11/8 )];
28142814
2815-                         mqk += (float4) (mq[i ] * mk);
2815+                         mqk += (float4) (mq[ii/NW ] * mk);
28162816                    }
28172817
28182818                    //  reduce the results from the threads in the simdgroup
@@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28572857                //  O = diag(ms)*O
28582858#pragma  unroll
28592859                for  (short  ii = 0 ; ii < D4; ii += NW) {
2860-                     const  short  i = ii + tiisg;
2861-                     lo[i/NW] *= ms;
2860+                     lo[ii/NW] *= ms;
28622861                }
28632862            }
28642863
@@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
28722871                    for  (short  ii = 0 ; ii < D4; ii += NW) {
28732872                        const  short  i = ii + tiisg;
28742873
2875-                         lo[i /NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2876-                         lo[i /NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2877-                         lo[i /NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2878-                         lo[i /NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
2874+                         lo[ii /NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2875+                         lo[ii /NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2876+                         lo[ii /NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2877+                         lo[ii /NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
28792878                    }
28802879                }
28812880            }
0 commit comments