@@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
13551355 KQHelper::convert (q_step, stride_q, q, q_f16);
13561356#endif
13571357 auto mr = mask;
1358- for (int k1 = 0 ; k1 < nk1/k_step; ++k1) {
1358+ auto Mc = (const uint16_t *)(mr + (q_step - 1 )*stride_m);
1359+ int ik = nk1 - k_step;
1360+ for (; ik >=0 && Mc[ik] != 0 ; ik -= k_step);
1361+ ik += k_step;
1362+ for (int k1 = 0 ; k1 < ik/k_step; ++k1) {
13591363#ifdef __aarch64__
13601364 KQHelper::multiply_mask_kq (kh, Dk, stride_m, q_f16, mr, fms);
13611365#else
@@ -1415,6 +1419,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
14151419 HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
14161420 auto mr = mask;
14171421 for (int k1 = 0 ; k1 < nk1/k_step; ++k1) {
1422+ auto Mc = (const uint16_t *)(mr + (q_step - 1 )*stride_m);
1423+ if (Mc[0 ] != 0 ) break ;
14181424 HelperQ80R8<Dk>::repack (k_step, kh.block , kh.stride , q8r8);
14191425 KQHelper::mul_mask_kq (khr8, stride_m, q8r, mr, fms);
14201426 fqkv.accumulate_qkv (vh, fms);
@@ -1441,7 +1447,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
14411447 perf.accum_nolock (0 , t1);
14421448#endif
14431449 auto mr = mask;
1444- for (int k1 = 0 ; k1 < nk1/k_step; ++k1) {
1450+ auto Mc = (const uint16_t *)(mr + (q_step - 1 )*stride_m);
1451+ int ik = nk1 - k_step;
1452+ for (; ik >=0 && Mc[ik] != 0 ; ik -= k_step);
1453+ ik += k_step;
1454+ for (int k1 = 0 ; k1 < ik/k_step; ++k1) {
14451455#if FA_TIMING
14461456 t1 = Perf::cur_time ();
14471457 KQHelper::mul_mask_kq (kh, stride_m, q8, mr, fms);
@@ -1959,7 +1969,11 @@ struct FlashAttnBF16 {
19591969 perf.accum_nolock (0 , t1);
19601970#endif
19611971 auto mr = mask;
1962- for (int k1 = 0 ; k1 < nk1/k_step; ++k1) {
1972+ auto Mc = (const uint16_t *)(mr + (q_step - 1 )*stride_m);
1973+ int ik = nk1 - k_step;
1974+ for (; ik >=0 && Mc[ik] != 0 ; ik -= k_step);
1975+ ik += k_step;
1976+ for (int k1 = 0 ; k1 < ik/k_step; ++k1) {
19631977#if FA_TIMING
19641978 // t1 = Perf::cur_time();
19651979 FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq (kh, stride_m, q_bf16, mr, fms, perf);
0 commit comments