@@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext(
33333333
33343334 threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
33353335 threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3336- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3337- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
33383336 threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + 2 *Q*DK); // scratch buffer for attention, mask and diagonal matrix
33393337
33403338 threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext(
35483546
35493547 // O = diag(ms)*O
35503548 {
3551- s8x8_t mm ;
3552- simdgroup_load (mm , ss + 2 *C, TS, 0 , false );
3549+ s8x8_t ms ;
3550+ simdgroup_load (ms , ss + 2 *C, TS, 0 , false );
35533551
35543552 #pragma unroll(DV8)
35553553 for (short i = 0 ; i < DV8; ++i) {
3556- simdgroup_multiply (lo[i], mm , lo[i]);
3554+ simdgroup_multiply (lo[i], ms , lo[i]);
35573555 }
35583556 }
35593557
35603558 // O = O + (Q*K^T)*V
35613559 {
35623560 for (short cc = 0 ; cc < C/8 ; ++cc) {
3563- s8x8_t ms ;
3564- simdgroup_load (ms , ss + 8 *cc, TS, 0 , false );
3561+ s8x8_t vs ;
3562+ simdgroup_load (vs , ss + 8 *cc, TS, 0 , false );
35653563
35663564 if (is_same<vd4x4_t , v4x4_t >::value) {
35673565 // we can read directly from global memory
@@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext(
35723570 v8x8_t mv;
35733571 simdgroup_load (mv, pv + i*8 , args.nb21 /sizeof (v_t ), 0 , false ); // TODO: use ne20
35743572
3575- simdgroup_multiply_accumulate (lo[i], ms , mv, lo[i]);
3573+ simdgroup_multiply_accumulate (lo[i], vs , mv, lo[i]);
35763574 }
35773575 } else {
35783576 for (short ii = 0 ; ii < DV16; ii += 4 ) {
@@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext(
35933591 v8x8_t mv;
35943592
35953593 simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3596- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3594+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
35973595
35983596 simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3599- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3597+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
36003598 }
36013599 } else {
36023600 if (ii + tx < DV16) {
@@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext(
36113609 v8x8_t mv;
36123610
36133611 simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3614- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3612+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
36153613
36163614 simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3617- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3615+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
36183616 }
36193617 }
36203618 }
@@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext(
36243622 }
36253623
36263624 // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3627- for (short j = 0 ; j < Q; ++j) {
3628- if (tiisg == 0 ) {
3629- ss[j*TS + 0 ] = S[j];
3630- ss[j*TS + 1 ] = M[j];
3631- }
3625+ for (short j = tiisg; j < Q; j += NW) {
3626+ ss[j*TS + 0 ] = S[j];
3627+ ss[j*TS + 1 ] = M[j];
36323628 }
36333629 }
36343630
3635- // reduce the warps sequentially
3636- for (ushort sg = 1 ; sg < nsg; ++sg) {
3637- threadgroup_barrier (mem_flags::mem_threadgroup);
3631+ threadgroup_barrier (mem_flags::mem_threadgroup);
36383632
3639- // each simdgroup stores its output to shared memory, reusing sq
3640- if (sgitg == sg) {
3641- for (short i = 0 ; i < DV8; ++i) {
3642- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3643- }
3633+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3634+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0 *DK);
3635+
3636+ // store result to shared memory in F32
3637+ if (sgitg == 0 ) {
3638+ for (short i = 0 ; i < DV8; ++i) {
3639+ // simdgroup_store(lo[i], so + i*8, DV, 0, false);
3640+ simdgroup_float8x8 t (1 .0f );
3641+ simdgroup_multiply (t, lo[i], t);
3642+ simdgroup_store (t, so + i*8 , DV, 0 , false );
36443643 }
3644+ }
36453645
3646- threadgroup_barrier (mem_flags::mem_threadgroup);
3646+ threadgroup_barrier (mem_flags::mem_threadgroup);
36473647
3648- // the first simdgroup accumulates the results from the other simdgroups
3649- if (sgitg == 0 ) {
3650- for (short j = 0 ; j < Q; ++j) {
3651- const float S0 = ss[j*TS + 0 ];
3652- const float S1 = ss[j*TS + sg*SH + 0 ];
3648+ // reduce the warps sequentially
3649+ for (ushort sg = 1 ; sg < nsg; ++sg) {
3650+ if (sgitg == sg) {
3651+ for (short j = tiisg; j < Q; j += NW) {
3652+ const float S0 = ss[j*TS - 1 *SH + 0 ];
3653+ const float S1 = ss[j*TS + 0 ];
36533654
3654- const float M0 = ss[j*TS + 1 ];
3655- const float M1 = ss[j*TS + sg*SH + 1 ];
3655+ const float M0 = ss[j*TS - 1 *SH + 1 ];
3656+ const float M1 = ss[j*TS + 1 ];
36563657
36573658 const float M = max (M0, M1);
36583659
3659- const float ms0 = exp (M0 - M);
3660- const float ms1 = exp (M1 - M);
3660+ float ms0 = exp (M0 - M);
3661+ float ms1 = exp (M1 - M);
36613662
36623663 const float S = S0*ms0 + S1*ms1;
36633664
3664- if (tiisg == 0 ) {
3665- ss[j*TS + 0 ] = S;
3666- ss[j*TS + 1 ] = M;
3665+ ss[j*TS + 0 ] = S;
3666+ ss[j*TS + 1 ] = M;
36673667
3668- ss[j*TS + 2 *C + j ] = ms0;
3669- ss[j*TS + 2 *C + j + sg*SH] = ms1;
3670- }
3668+ ss[j*TS + 2 *C + j - 1 *SH] = ms0;
3669+ ss[j*TS + 2 *C + j ] = ms1;
36713670 }
36723671
3672+ // simdgroup_barrier(mem_flags::mem_threadgroup);
3673+
36733674 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
36743675 {
36753676 s8x8_t ms0;
36763677 s8x8_t ms1;
36773678
3678- simdgroup_load (ms0, ss + 2 *C, TS, 0 , false );
3679- simdgroup_load (ms1, ss + 2 *C + sg*SH, TS, 0 , false );
3679+ simdgroup_load (ms0, ss + 2 *C - 1 *SH, TS, 0 , false );
3680+ simdgroup_load (ms1, ss + 2 *C, TS, 0 , false );
36803681
36813682 #pragma unroll(DV8)
36823683 for (short i = 0 ; i < DV8; ++i) {
3683- o8x8_t t;
3684+ simdgroup_float8x8 t;
36843685
36853686 simdgroup_load (t, so + i*8 , DV, 0 , false );
3686- simdgroup_multiply (t, ms1 , t);
3687+ simdgroup_multiply (t, ms0 , t);
36873688
3688- simdgroup_multiply_accumulate (lo[i], ms0, lo[i], t);
3689+ simdgroup_multiply_accumulate (t, ms1, lo[i], t);
3690+ simdgroup_store (t, so + i*8 , DV, 0 , false );
36893691 }
36903692 }
36913693 }
3692- }
36933694
3694- // store result to shared memory (reuse sq)
3695- if (sgitg == 0 ) {
3696- for (short i = 0 ; i < DV8; ++i) {
3697- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3698- }
3695+ threadgroup_barrier (mem_flags::mem_threadgroup);
36993696 }
37003697
3701- threadgroup_barrier (mem_flags::mem_threadgroup);
3702-
3703- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *Q*DK);
3698+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *(nsg-1 )*SH + 2 *Q*DK);
37043699
37053700 // final rescale with 1/S and store to global memory
37063701 for (short j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
@@ -3723,17 +3718,17 @@ kernel void kernel_flash_attn_ext(
37233718 half, half4x4, simdgroup_half8x8, \
37243719 float , simdgroup_float8x8, \
37253720 float , simdgroup_float8x8, \
3726- float , float4 , simdgroup_float8x8
3727- // half , half4 , simdgroup_half8x8
3721+ half , half4 , simdgroup_half8x8
3722+ // float , float4 , simdgroup_float8x8
37283723
37293724#define FA_TYPES_BF \
37303725 bfloat, bfloat4, simdgroup_bfloat8x8, \
37313726 bfloat, bfloat4x4, simdgroup_bfloat8x8, \
37323727 bfloat, bfloat4x4, simdgroup_bfloat8x8, \
37333728 float , simdgroup_float8x8, \
37343729 float , simdgroup_float8x8, \
3735- float , float4 , simdgroup_float8x8
3736- // half , half4 , simdgroup_half8x8
3730+ half , half4 , simdgroup_half8x8
3731+ // float , float4 , simdgroup_float8x8
37373732
37383733typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
37393734
0 commit comments