1 file changed: +3 -3 lines changed
@@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
27762776 const short iv3 = iq3 / rv3;
27772777
27782778 // load the queries from shared memory into local memory
2779- float4 mq[D4];
2779+ float4 mq[D4/NW ];
27802780
27812781 for (short ii = 0 ; ii < D4; ii += NW) {
27822782 short i = ii + tiisg;
2783- mq[i] = (float4) sq4[i];
2783+ mq[i/NW ] = (float4) sq4[i];
27842784 }
27852785
27862786 // pointer to the mask
@@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28122812 mk[2 ] = (float4) pk4[i + 2 *(nb11/8 )];
28132813 mk[3 ] = (float4) pk4[i + 3 *(nb11/8 )];
28142814
2815- mqk += (float4) (mq[i] * mk);
2815+ mqk += (float4) (mq[i/NW ] * mk);
28162816 }
28172817
28182818 // reduce the results from the threads in the simdgroup
You can’t perform that action at this time.
0 commit comments