@@ -3199,13 +3199,6 @@ kernel void kernel_flash_attn_ext(
31993199 const short ikv2 = iq2/(args.ne02 /args.ne_12_2 );
32003200 const short ikv3 = iq3/(args.ne03 /args.ne_12_3 );
32013201
3202- // load the queries from shared memory into local memory
3203- q8x8_t mq[DK8];
3204-
3205- for (short i = 0 ; i < DK8; ++i) {
3206- simdgroup_load (mq[i], sq + i*8 , DK);
3207- }
3208-
32093202 const bool has_mask = mask != q;
32103203
32113204 half slope = 1 .0f ;
@@ -3265,7 +3258,9 @@ kernel void kernel_flash_attn_ext(
32653258 k8x8_t mk;
32663259 simdgroup_load (mk, pk + i*8 , args.nb11 /sizeof (k_t ), 0 , true ); // transpose // TODO: use ne10
32673260
3268- simdgroup_multiply_accumulate (mqk, mq[i], mk, mqk);
3261+ q8x8_t mq;
3262+ simdgroup_load (mq, sq + i*8 , DK);
3263+ simdgroup_multiply_accumulate (mqk, mq, mk, mqk);
32693264 }
32703265 } else {
32713266 for (short ii = 0 ; ii < DK16; ii += 4 ) {
@@ -3284,12 +3279,15 @@ kernel void kernel_flash_attn_ext(
32843279 #pragma unroll(4)
32853280 for (short k = 0 ; k < 4 ; ++k) {
32863281 k8x8_t mk;
3282+ q8x8_t mq;
32873283
32883284 simdgroup_load (mk, sk + 16 *k + 0 *8 , 4 *16 , 0 , true ); // transpose
3289- simdgroup_multiply_accumulate (mqk, mq[2 *(ii + k) + 0 ], mk, mqk);
3285+ simdgroup_load (mq, sq + (2 *(ii + k) + 0 )*8 , DK);
3286+ simdgroup_multiply_accumulate (mqk, mq, mk, mqk);
32903287
32913288 simdgroup_load (mk, sk + 16 *k + 1 *8 , 4 *16 , 0 , true ); // transpose
3292- simdgroup_multiply_accumulate (mqk, mq[2 *(ii + k) + 1 ], mk, mqk);
3289+ simdgroup_load (mq, sq + (2 *(ii + k) + 1 )*8 , DK);
3290+ simdgroup_multiply_accumulate (mqk, mq, mk, mqk);
32933291 }
32943292 } else {
32953293 if (ii + tx < DK16) {
@@ -3302,12 +3300,15 @@ kernel void kernel_flash_attn_ext(
33023300
33033301 for (short k = 0 ; k < 4 && ii + k < DK16; ++k) {
33043302 k8x8_t mk;
3303+ q8x8_t mq;
33053304
33063305 simdgroup_load (mk, sk + 16 *k + 0 *8 , 4 *16 , 0 , true ); // transpose
3307- simdgroup_multiply_accumulate (mqk, mq[2 *(ii + k) + 0 ], mk, mqk);
3306+ simdgroup_load (mq, sq + (2 *(ii + k) + 0 )*8 , DK);
3307+ simdgroup_multiply_accumulate (mqk, mq, mk, mqk);
33083308
33093309 simdgroup_load (mk, sk + 16 *k + 1 *8 , 4 *16 , 0 , true ); // transpose
3310- simdgroup_multiply_accumulate (mqk, mq[2 *(ii + k) + 1 ], mk, mqk);
3310+ simdgroup_load (mq, sq + (2 *(ii + k) + 1 )*8 , DK);
3311+ simdgroup_multiply_accumulate (mqk, mq, mk, mqk);
33113312 }
33123313 }
33133314 }
0 commit comments