@@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
33563356 const short D4 = D/4 ;
33573357 const short D16 = D/16 ;
33583358 const short NW = N_SIMDWIDTH;
3359- const short NW4 = NW/4 ;
3359+ const short NL = NW/4 ;
33603360 const short SH = 2 *C; // shared memory per simdgroup
33613361
33623362 const short T = D + nsg*SH; // shared memory size per query in (half)
@@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
33703370 threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
33713371
33723372 // store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
3373- o4x4_t lo[D16/NW4 ];
3373+ o4x4_t lo[D16/NL ];
33743374
33753375 // load heads from Q to shared memory
33763376 device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
33843384 }
33853385
33863386 // zero out lo
3387- for (short i = 0 ; i < D16/NW4; i += NW4 ) {
3387+ for (short i = 0 ; i < D16/NL; ++i ) {
33883388 lo[i] = (o4x4_t ) 0 .0f ;
33893389 }
33903390
@@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
34003400 half M = -__FLT16_MAX__/2 ;
34013401
34023402 // thread indices inside the simdgroup
3403- const short tx = tiisg%8 ;
3404- const short ty = tiisg/8 ;
3403+ const short tx = tiisg%NL ;
3404+ const short ty = tiisg/NL ;
34053405
34063406 // broadcast kv
34073407 // const short rk2 = ne02/ne12;
@@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
34113411 const short ikv3 = iq3/(ne03/ne_12_3);
34123412
34133413 // load the queries from shared memory into local memory
3414- q4x4_t mq[D16/NW4 ];
3414+ q4x4_t mq[D16/NL ];
34153415
3416- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3417- mq[ii/NW4 ] = sq4x4[ii + tx];
3416+ for (short ii = 0 ; ii < D16; ii += NL ) {
3417+ mq[ii/NL ] = sq4x4[ii + tx];
34183418 }
34193419
34203420 const bool has_mask = mask != q;
@@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
34553455 device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563456
34573457#pragma unroll
3458- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3458+ for (short ii = 0 ; ii < D16; ii += NL ) {
34593459 const short i = ii + tx;
34603460
34613461 k4x4_t mk;
34623462 deq_k (pk + i/nl_k, i%nl_k, mk);
34633463
34643464 mqk +=
3465- dot (mq[ii/NW4 ][0 ], mk[0 ]) +
3466- dot (mq[ii/NW4 ][1 ], mk[1 ]) +
3467- dot (mq[ii/NW4 ][2 ], mk[2 ]) +
3468- dot (mq[ii/NW4 ][3 ], mk[3 ]);
3465+ dot (mq[ii/NL ][0 ], mk[0 ]) +
3466+ dot (mq[ii/NL ][1 ], mk[1 ]) +
3467+ dot (mq[ii/NL ][2 ], mk[2 ]) +
3468+ dot (mq[ii/NL ][3 ], mk[3 ]);
34693469 }
34703470
34713471 // simdgroup reduce
@@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(
35133513
35143514 // O = diag(ms)*O
35153515#pragma unroll
3516- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3517- lo[ii/NW4 ] *= ms;
3516+ for (short ii = 0 ; ii < D16; ii += NL ) {
3517+ lo[ii/NL ] *= ms;
35183518 }
35193519 }
35203520
@@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
35293529 const s4x4_t ms (ss[4 *cc + ty]);
35303530
35313531#pragma unroll
3532- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3532+ for (short ii = 0 ; ii < D16; ii += NL ) {
35333533 const short i = ii + tx;
35343534
35353535 v4x4_t mv;
35363536 deq_v (pv4 + i/nl_v, i%nl_v, mv);
35373537
3538- lo[ii/NW4 ] += mv*ms;
3538+ lo[ii/NL ] += mv*ms;
35393539 }
35403540 }
35413541 }
@@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
35573557 // [ 5, 13, 21, 29] -> [ 5]
35583558 // [ 6, 14, 22, 30] -> [ 6]
35593559 // [ 7, 15, 23, 31] -> [ 7]
3560- for (short ii = 0 ; ii < D16; ii += NW4) {
3561- lo[ii/NW4][0 ] += simd_shuffle_down (lo[ii/NW4][0 ], 16 );
3562- lo[ii/NW4][0 ] += simd_shuffle_down (lo[ii/NW4][0 ], 8 );
3563-
3564- lo[ii/NW4][1 ] += simd_shuffle_down (lo[ii/NW4][1 ], 16 );
3565- lo[ii/NW4][1 ] += simd_shuffle_down (lo[ii/NW4][1 ], 8 );
3566-
3567- lo[ii/NW4][2 ] += simd_shuffle_down (lo[ii/NW4][2 ], 16 );
3568- lo[ii/NW4][2 ] += simd_shuffle_down (lo[ii/NW4][2 ], 8 );
3569-
3570- lo[ii/NW4][3 ] += simd_shuffle_down (lo[ii/NW4][3 ], 16 );
3571- lo[ii/NW4][3 ] += simd_shuffle_down (lo[ii/NW4][3 ], 8 );
3560+ for (short ii = 0 ; ii < D16; ii += NL) {
3561+ lo[ii/NL][0 ] += simd_shuffle_down (lo[ii/NL][0 ], 16 );
3562+ lo[ii/NL][0 ] += simd_shuffle_down (lo[ii/NL][0 ], 8 );
3563+ // lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
3564+ // lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
3565+ // lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
3566+
3567+ lo[ii/NL][1 ] += simd_shuffle_down (lo[ii/NL][1 ], 16 );
3568+ lo[ii/NL][1 ] += simd_shuffle_down (lo[ii/NL][1 ], 8 );
3569+ // lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
3570+ // lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
3571+ // lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
3572+
3573+ lo[ii/NL][2 ] += simd_shuffle_down (lo[ii/NL][2 ], 16 );
3574+ lo[ii/NL][2 ] += simd_shuffle_down (lo[ii/NL][2 ], 8 );
3575+ // lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
3576+ // lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
3577+ // lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
3578+
3579+ lo[ii/NL][3 ] += simd_shuffle_down (lo[ii/NL][3 ], 16 );
3580+ lo[ii/NL][3 ] += simd_shuffle_down (lo[ii/NL][3 ], 8 );
3581+ // lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
3582+ // lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
3583+ // lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
35723584 }
35733585
3586+ threadgroup_barrier (mem_flags::mem_threadgroup);
3587+
35743588 // store results to shared memory
3575- for (short i = tiisg; i < D16; i += NW4 ) {
3576- sr4x4[i] = lo[i/NW4 ];
3589+ for (short i = tiisg; i < D16; i += NL ) {
3590+ sr4x4[i] = lo[i/NL ];
35773591 }
35783592
35793593 threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments