Skip to content

Commit eedb5ce

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents f147107 + c108e4b commit eedb5ce

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

ggml/src/iqk/fa/iqk_fa_templates.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,11 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
13551355
KQHelper::convert(q_step, stride_q, q, q_f16);
13561356
#endif
13571357
auto mr = mask;
1358-
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
1358+
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
1359+
int ik = nk1 - k_step;
1360+
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
1361+
ik += k_step;
1362+
for (int k1 = 0; k1 < ik/k_step; ++k1) {
13591363
#ifdef __aarch64__
13601364
KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms);
13611365
#else
@@ -1415,6 +1419,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
14151419
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
14161420
auto mr = mask;
14171421
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
1422+
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
1423+
if (Mc[0] != 0) break;
14181424
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
14191425
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
14201426
fqkv.accumulate_qkv(vh, fms);
@@ -1441,7 +1447,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
14411447
perf.accum_nolock(0, t1);
14421448
#endif
14431449
auto mr = mask;
1444-
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
1450+
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
1451+
int ik = nk1 - k_step;
1452+
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
1453+
ik += k_step;
1454+
for (int k1 = 0; k1 < ik/k_step; ++k1) {
14451455
#if FA_TIMING
14461456
t1 = Perf::cur_time();
14471457
KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms);
@@ -1959,7 +1969,11 @@ struct FlashAttnBF16 {
19591969
perf.accum_nolock(0, t1);
19601970
#endif
19611971
auto mr = mask;
1962-
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
1972+
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
1973+
int ik = nk1 - k_step;
1974+
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
1975+
ik += k_step;
1976+
for (int k1 = 0; k1 < ik/k_step; ++k1) {
19631977
#if FA_TIMING
19641978
//t1 = Perf::cur_time();
19651979
FlashQKbf16<Dk, q_step, k_step>::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf);

0 commit comments

Comments
 (0)