@@ -2049,27 +2049,24 @@ typedef void (flash_attn_ext_f16_t)(
20492049 device const char * v,
20502050 device const char * mask,
20512051 device float * dst,
2052- constant int64_t & ne00,
20532052 constant int64_t & ne01,
20542053 constant int64_t & ne02,
20552054 constant int64_t & ne03,
2056- constant uint64_t & nb00,
20572055 constant uint64_t & nb01,
20582056 constant uint64_t & nb02,
20592057 constant uint64_t & nb03,
2060- constant int64_t & ne10,
20612058 constant int64_t & ne11,
20622059 constant int64_t & ne12,
20632060 constant int64_t & ne13,
2064- constant uint64_t & nb10,
20652061 constant uint64_t & nb11,
20662062 constant uint64_t & nb12,
20672063 constant uint64_t & nb13,
2064+ constant uint64_t & nb21,
2065+ constant uint64_t & nb22,
2066+ constant uint64_t & nb23,
20682067 constant uint64_t & nb31,
2069- constant int64_t & ne0,
20702068 constant int64_t & ne1,
20712069 constant int64_t & ne2,
2072- constant int64_t & ne3,
20732070 constant float & scale,
20742071 constant float & max_bias,
20752072 constant float & m0,
@@ -2090,27 +2087,24 @@ kernel void kernel_flash_attn_ext_f16(
20902087 device const char * v,
20912088 device const char * mask,
20922089 device float * dst,
2093- constant int64_t & ne00,
20942090 constant int64_t & ne01,
20952091 constant int64_t & ne02,
20962092 constant int64_t & ne03,
2097- constant uint64_t & nb00,
20982093 constant uint64_t & nb01,
20992094 constant uint64_t & nb02,
21002095 constant uint64_t & nb03,
2101- constant int64_t & ne10,
21022096 constant int64_t & ne11,
21032097 constant int64_t & ne12,
21042098 constant int64_t & ne13,
2105- constant uint64_t & nb10,
21062099 constant uint64_t & nb11,
21072100 constant uint64_t & nb12,
21082101 constant uint64_t & nb13,
2102+ constant uint64_t & nb21,
2103+ constant uint64_t & nb22,
2104+ constant uint64_t & nb23,
21092105 constant uint64_t & nb31,
2110- constant int64_t & ne0,
21112106 constant int64_t & ne1,
21122107 constant int64_t & ne2,
2113- constant int64_t & ne3,
21142108 constant float & scale,
21152109 constant float & max_bias,
21162110 constant float & m0,
@@ -2180,10 +2174,6 @@ kernel void kernel_flash_attn_ext_f16(
21802174 const short ne22 = ne12;
21812175 const short ne23 = ne13;
21822176
2183- const uint nb21 = nb11;
2184- const uint nb22 = nb12;
2185- const uint nb23 = nb13;
2186-
21872177 // broadcast
21882178 const short rk2 = ne02/ne12;
21892179 const short rk3 = ne03/ne13;
@@ -2247,11 +2237,16 @@ kernel void kernel_flash_attn_ext_f16(
22472237 simdgroup_multiply_accumulate (mqk, mq[i], mk, mqk);
22482238 }
22492239
2250- // mqk = mqk*scale + mask*slope
2251- simdgroup_half8x8 mm;
2252- simdgroup_load (mm, mp + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2253- simdgroup_multiply (mm, mslope, mm);
2254- simdgroup_multiply_accumulate (mqk, mqk, mscale, mm);
2240+ if (mask != q) {
2241+ // mqk = mqk*scale + mask*slope
2242+ simdgroup_half8x8 mm;
2243+ simdgroup_load (mm, mp + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2244+ simdgroup_multiply (mm, mslope, mm);
2245+ simdgroup_multiply_accumulate (mqk, mqk, mscale, mm);
2246+ } else {
2247+ // mqk = mqk*scale
2248+ simdgroup_multiply (mqk, mscale, mqk);
2249+ }
22552250
22562251 simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
22572252 }
@@ -2425,27 +2420,24 @@ kernel void kernel_flash_attn_ext_vec_f16(
24252420 device const char * v,
24262421 device const char * mask,
24272422 device float * dst,
2428- constant int64_t & ne00,
24292423 constant int64_t & ne01,
24302424 constant int64_t & ne02,
24312425 constant int64_t & ne03,
2432- constant uint64_t & nb00,
24332426 constant uint64_t & nb01,
24342427 constant uint64_t & nb02,
24352428 constant uint64_t & nb03,
2436- constant int64_t & ne10,
24372429 constant int64_t & ne11,
24382430 constant int64_t & ne12,
24392431 constant int64_t & ne13,
2440- constant uint64_t & nb10,
24412432 constant uint64_t & nb11,
24422433 constant uint64_t & nb12,
24432434 constant uint64_t & nb13,
2435+ constant uint64_t & nb21,
2436+ constant uint64_t & nb22,
2437+ constant uint64_t & nb23,
24442438 constant uint64_t & nb31,
2445- constant int64_t & ne0,
24462439 constant int64_t & ne1,
24472440 constant int64_t & ne2,
2448- constant int64_t & ne3,
24492441 constant float & scale,
24502442 constant float & max_bias,
24512443 constant float & m0,
@@ -2521,10 +2513,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
25212513 const short ne22 = ne12;
25222514 const short ne23 = ne13;
25232515
2524- const uint nb21 = nb11;
2525- const uint nb22 = nb12;
2526- const uint nb23 = nb13;
2527-
25282516 // broadcast
25292517 const short rk2 = ne02/ne12;
25302518 const short rk3 = ne03/ne13;
@@ -2589,8 +2577,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
25892577
25902578 // mqk = mqk*scale + mask*slope
25912579 if (tiisg == 0 ) {
2592- float4 mm = (float4) mp4[ic/4 + cc];
2593- mqk = mqk*scale + mm*slope;
2580+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0 .0f );
25942581
25952582 ss4[cc] = mqk;
25962583 }
0 commit comments