File tree Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
26312631 const short iv3 = iq3 / rv3;
26322632
26332633 // load the queries from shared memory into local memory
2634- half4 mq[D4];
2634+ float4 mq[D4];
26352635
26362636 for (short ii = 0 ; ii < D4; ii += NW) {
26372637 short i = ii + tiisg;
2638- mq[i] = sq4[i];
2638+ mq[i] = (float4) sq4[i];
26392639 }
26402640
26412641 // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
26612661 for (short ii = 0 ; ii < D4; ii += NW) {
26622662 const short i = ii + tiisg;
26632663
2664- half4x4 mk;
2665- mk[0 ] = pk4[i + 0 *(nb11/8 )];
2666- mk[1 ] = pk4[i + 1 *(nb11/8 )];
2667- mk[2 ] = pk4[i + 2 *(nb11/8 )];
2668- mk[3 ] = pk4[i + 3 *(nb11/8 )];
2664+ float4x4 mk;
2665+ mk[0 ] = (float4) pk4[i + 0 *(nb11/8 )];
2666+ mk[1 ] = (float4) pk4[i + 1 *(nb11/8 )];
2667+ mk[2 ] = (float4) pk4[i + 2 *(nb11/8 )];
2668+ mk[3 ] = (float4) pk4[i + 3 *(nb11/8 )];
26692669
26702670 mqk += (float4) (mq[i] * mk);
26712671 }
You can’t perform that action at this time.
0 commit comments