@@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
22612261 }
22622262
22632263 simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2264-
2265- const short tx = tiisg%4 ;
2266- const short ty = tiisg/4 ;
2267-
2268- // mqk = mqk*scale
2269- ss[8 *cc + ty*TF + 2 *tx + 0 ] *= scale;
2270- ss[8 *cc + ty*TF + 2 *tx + 1 ] *= scale;
2271-
2272- if (logit_softcap != 0 .0f ) {
2273- ss[8 *cc + ty*TF + 2 *tx + 0 ] = logit_softcap*precise::tanh (ss[8 *cc + ty*TF + 2 *tx + 0 ]);
2274- ss[8 *cc + ty*TF + 2 *tx + 1 ] = logit_softcap*precise::tanh (ss[8 *cc + ty*TF + 2 *tx + 1 ]);
2275- }
2276-
2277- if (mask != q) {
2278- // mqk = mqk + mask*slope
2279- ss[8 *cc + ty*TF + 2 *tx + 0 ] += slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 0 ];
2280- ss[8 *cc + ty*TF + 2 *tx + 1 ] += slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 1 ];
2281- }
22822264 }
22832265 }
22842266
@@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
22902272 float ms[Q];
22912273
22922274 for (short j = 0 ; j < Q; ++j) {
2293- const short p = tiisg;
2294-
22952275 const float m = M[j];
2296- const float s = ss[j*TF + p];
2276+
2277+ // scale and apply the logitcap / mask
2278+ float s = ss[j*TF + tiisg]*scale;
2279+
2280+ if (logit_softcap != 0 .0f ) {
2281+ s = logit_softcap*precise::tanh (s);
2282+ }
2283+
2284+ if (mask != q) {
2285+ // mqk = mqk + mask*slope
2286+ s += slope*mp[ic + j*nb31/sizeof (half) + tiisg];
2287+ }
22972288
22982289 smax = simd_max (max (smax, s));
22992290 M[j] = simd_max (max (M[j], s));
@@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
23042295 S[j] = S[j]*ms[j] + simd_sum (vs);
23052296
23062297 // the P matrix from the paper (Q rows, C columns)
2307- ss[j*TF + p ] = vs;
2298+ ss[j*TF + tiisg ] = vs;
23082299 }
23092300
23102301 // create a QxQ diagonal matrix for rescaling the output
0 commit comments