@@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
22042204 // pointer to the mask
22052205 device const half * mp = (device const half *) (mask + iq1*nb31);
22062206
2207- // prepare diagonal scale matrix
2208- simdgroup_float8x8 mscale (scale);
2209-
2210- // prepare diagonal slope matrix
2211- simdgroup_float8x8 mslope (1 .0f );
2207+ float slope = 1 .0f ;
22122208
22132209 // ALiBi
22142210 if (max_bias > 0 .0f ) {
@@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
22172213 const float base = h < n_head_log2 ? m0 : m1;
22182214 const int exph = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
22192215
2220- mslope = simdgroup_float8x8 ( pow (base, exph) );
2216+ slope = pow (base, exph);
22212217 }
22222218
22232219 // loop over the KV cache
@@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
22422238 simdgroup_multiply_accumulate (mqk, mq[i], mk, mqk);
22432239 }
22442240
2241+ simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2242+
2243+ const short tx = tiisg%4 ;
2244+ const short ty = tiisg/4 ;
2245+
22452246 if (mask != q) {
22462247 // mqk = mqk*scale + mask*slope
2247- simdgroup_half8x8 mm;
2248- simdgroup_load (mm, mp + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2249- simdgroup_multiply (mm, mslope, mm);
2250- simdgroup_multiply_accumulate (mqk, mqk, mscale, mm);
2248+ ss[8 *cc + ty*TF + 2 *tx + 0 ] = scale*ss[8 *cc + ty*TF + 2 *tx + 0 ] + slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 0 ];
2249+ ss[8 *cc + ty*TF + 2 *tx + 1 ] = scale*ss[8 *cc + ty*TF + 2 *tx + 1 ] + slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 1 ];
22512250 } else {
22522251 // mqk = mqk*scale
2253- simdgroup_multiply (mqk, mscale, mqk);
2252+ ss[8 *cc + ty*TF + 2 *tx + 0 ] *= scale;
2253+ ss[8 *cc + ty*TF + 2 *tx + 1 ] *= scale;
22542254 }
2255-
2256- simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
22572255 }
22582256 }
22592257
@@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
28162814 for (int64_t i00 = tpitg.x ; i00 < ne00; i00 += ntg.x ) {
28172815 device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
28182816
2819- // TODO: is there a better way to handle -INFINITY?
2820- dst_data[i00] = src[0 ] == -INFINITY ? -MAXHALF : src[0 ];
2817+ dst_data[i00] = src[0 ];
28212818 }
28222819}
28232820
0 commit comments