Skip to content

Commit e1e56f7

Browse files
committed
metal : FA remove mq registers
1 parent a444d39 commit e1e56f7

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3199,13 +3199,6 @@ kernel void kernel_flash_attn_ext(
31993199
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
32003200
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
32013201

3202-
// load the queries from shared memory into local memory
3203-
q8x8_t mq[DK8];
3204-
3205-
for (short i = 0; i < DK8; ++i) {
3206-
simdgroup_load(mq[i], sq + i*8, DK);
3207-
}
3208-
32093202
const bool has_mask = mask != q;
32103203

32113204
half slope = 1.0f;
@@ -3265,7 +3258,9 @@ kernel void kernel_flash_attn_ext(
32653258
k8x8_t mk;
32663259
simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
32673260

3268-
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
3261+
q8x8_t mq;
3262+
simdgroup_load(mq, sq + i*8, DK);
3263+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
32693264
}
32703265
} else {
32713266
for (short ii = 0; ii < DK16; ii += 4) {
@@ -3284,12 +3279,15 @@ kernel void kernel_flash_attn_ext(
32843279
#pragma unroll(4)
32853280
for (short k = 0; k < 4; ++k) {
32863281
k8x8_t mk;
3282+
q8x8_t mq;
32873283

32883284
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3289-
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
3285+
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
3286+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
32903287

32913288
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3292-
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
3289+
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
3290+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
32933291
}
32943292
} else {
32953293
if (ii + tx < DK16) {
@@ -3302,12 +3300,15 @@ kernel void kernel_flash_attn_ext(
33023300

33033301
for (short k = 0; k < 4 && ii + k < DK16; ++k) {
33043302
k8x8_t mk;
3303+
q8x8_t mq;
33053304

33063305
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3307-
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
3306+
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
3307+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
33083308

33093309
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3310-
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
3310+
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
3311+
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
33113312
}
33123313
}
33133314
}

0 commit comments

Comments
 (0)