@@ -3406,8 +3406,6 @@ kernel void kernel_flash_attn_ext(
34063406
34073407 threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
34083408 threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3409- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3410- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
34113409 threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + 2 *Q*DK); // scratch buffer for attention, mask and diagonal matrix
34123410
34133411 threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3621,20 +3619,20 @@ kernel void kernel_flash_attn_ext(
36213619
36223620 // O = diag(ms)*O
36233621 {
3624- s8x8_t mm ;
3625- simdgroup_load (mm , ss + 2 *C, TS, 0 , false );
3622+ s8x8_t ms ;
3623+ simdgroup_load (ms , ss + 2 *C, TS, 0 , false );
36263624
36273625 #pragma unroll(DV8)
36283626 for (short i = 0 ; i < DV8; ++i) {
3629- simdgroup_multiply (lo[i], mm , lo[i]);
3627+ simdgroup_multiply (lo[i], ms , lo[i]);
36303628 }
36313629 }
36323630
36333631 // O = O + (Q*K^T)*V
36343632 {
36353633 for (short cc = 0 ; cc < C/8 ; ++cc) {
3636- s8x8_t ms ;
3637- simdgroup_load (ms , ss + 8 *cc, TS, 0 , false );
3634+ s8x8_t vs ;
3635+ simdgroup_load (vs , ss + 8 *cc, TS, 0 , false );
36383636
36393637 if (is_same<vd4x4_t , v4x4_t >::value) {
36403638 // we can read directly from global memory
@@ -3645,7 +3643,7 @@ kernel void kernel_flash_attn_ext(
36453643 v8x8_t mv;
36463644 simdgroup_load (mv, pv + i*8 , args.nb21 /sizeof (v_t ), 0 , false ); // TODO: use ne20
36473645
3648- simdgroup_multiply_accumulate (lo[i], ms , mv, lo[i]);
3646+ simdgroup_multiply_accumulate (lo[i], vs , mv, lo[i]);
36493647 }
36503648 } else {
36513649 for (short ii = 0 ; ii < DV16; ii += 4 ) {
@@ -3666,10 +3664,10 @@ kernel void kernel_flash_attn_ext(
36663664 v8x8_t mv;
36673665
36683666 simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3669- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3667+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
36703668
36713669 simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3672- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3670+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
36733671 }
36743672 } else {
36753673 if (ii + tx < DV16) {
@@ -3684,10 +3682,10 @@ kernel void kernel_flash_attn_ext(
36843682 v8x8_t mv;
36853683
36863684 simdgroup_load (mv, sv + 16 *k + 0 *8 , 4 *16 , 0 , false );
3687- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], ms , mv, lo[2 *(ii + k) + 0 ]);
3685+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 0 ], vs , mv, lo[2 *(ii + k) + 0 ]);
36883686
36893687 simdgroup_load (mv, sv + 16 *k + 1 *8 , 4 *16 , 0 , false );
3690- simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms , mv, lo[2 *(ii + k) + 1 ]);
3688+ simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], vs , mv, lo[2 *(ii + k) + 1 ]);
36913689 }
36923690 }
36933691 }
@@ -3697,83 +3695,80 @@ kernel void kernel_flash_attn_ext(
36973695 }
36983696
36993697 // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3700- for (short j = 0 ; j < Q; ++j) {
3701- if (tiisg == 0 ) {
3702- ss[j*TS + 0 ] = S[j];
3703- ss[j*TS + 1 ] = M[j];
3704- }
3698+ for (short j = tiisg; j < Q; j += NW) {
3699+ ss[j*TS + 0 ] = S[j];
3700+ ss[j*TS + 1 ] = M[j];
37053701 }
37063702 }
37073703
3708- // reduce the warps sequentially
3709- for (ushort sg = 1 ; sg < nsg; ++sg) {
3710- threadgroup_barrier (mem_flags::mem_threadgroup);
3704+ threadgroup_barrier (mem_flags::mem_threadgroup);
37113705
3712- // each simdgroup stores its output to shared memory, reusing sq
3713- if (sgitg == sg) {
3714- for (short i = 0 ; i < DV8; ++i) {
3715- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3716- }
3706+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3707+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0 *DK);
3708+
3709+ // store result to shared memory in F32
3710+ if (sgitg == 0 ) {
3711+ for (short i = 0 ; i < DV8; ++i) {
3712+ // simdgroup_store(lo[i], so + i*8, DV, 0, false);
3713+ simdgroup_float8x8 t (1 .0f );
3714+ simdgroup_multiply (t, lo[i], t);
3715+ simdgroup_store (t, so + i*8 , DV, 0 , false );
37173716 }
3717+ }
37183718
3719- threadgroup_barrier (mem_flags::mem_threadgroup);
3719+ threadgroup_barrier (mem_flags::mem_threadgroup);
37203720
3721- // the first simdgroup accumulates the results from the other simdgroups
3722- if (sgitg == 0 ) {
3723- for (short j = 0 ; j < Q; ++j) {
3724- const float S0 = ss[j*TS + 0 ];
3725- const float S1 = ss[j*TS + sg*SH + 0 ];
3721+ // reduce the warps sequentially
3722+ for (ushort sg = 1 ; sg < nsg; ++sg) {
3723+ if (sgitg == sg) {
3724+ for (short j = tiisg; j < Q; j += NW) {
3725+ const float S0 = ss[j*TS - 1 *SH + 0 ];
3726+ const float S1 = ss[j*TS + 0 ];
37263727
3727- const float M0 = ss[j*TS + 1 ];
3728- const float M1 = ss[j*TS + sg*SH + 1 ];
3728+ const float M0 = ss[j*TS - 1 *SH + 1 ];
3729+ const float M1 = ss[j*TS + 1 ];
37293730
37303731 const float M = max (M0, M1);
37313732
3732- const float ms0 = exp (M0 - M);
3733- const float ms1 = exp (M1 - M);
3733+ float ms0 = exp (M0 - M);
3734+ float ms1 = exp (M1 - M);
37343735
37353736 const float S = S0*ms0 + S1*ms1;
37363737
3737- if (tiisg == 0 ) {
3738- ss[j*TS + 0 ] = S;
3739- ss[j*TS + 1 ] = M;
3738+ ss[j*TS + 0 ] = S;
3739+ ss[j*TS + 1 ] = M;
37403740
3741- ss[j*TS + 2 *C + j ] = ms0;
3742- ss[j*TS + 2 *C + j + sg*SH] = ms1;
3743- }
3741+ ss[j*TS + 2 *C + j - 1 *SH] = ms0;
3742+ ss[j*TS + 2 *C + j ] = ms1;
37443743 }
37453744
3745+ // simdgroup_barrier(mem_flags::mem_threadgroup);
3746+
37463747 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
37473748 {
37483749 s8x8_t ms0;
37493750 s8x8_t ms1;
37503751
3751- simdgroup_load (ms0, ss + 2 *C, TS, 0 , false );
3752- simdgroup_load (ms1, ss + 2 *C + sg*SH, TS, 0 , false );
3752+ simdgroup_load (ms0, ss + 2 *C - 1 *SH, TS, 0 , false );
3753+ simdgroup_load (ms1, ss + 2 *C, TS, 0 , false );
37533754
37543755 #pragma unroll(DV8)
37553756 for (short i = 0 ; i < DV8; ++i) {
3756- o8x8_t t;
3757+ simdgroup_float8x8 t;
37573758
37583759 simdgroup_load (t, so + i*8 , DV, 0 , false );
3759- simdgroup_multiply (t, ms1 , t);
3760+ simdgroup_multiply (t, ms0 , t);
37603761
3761- simdgroup_multiply_accumulate (lo[i], ms0, lo[i], t);
3762+ simdgroup_multiply_accumulate (t, ms1, lo[i], t);
3763+ simdgroup_store (t, so + i*8 , DV, 0 , false );
37623764 }
37633765 }
37643766 }
3765- }
37663767
3767- // store result to shared memory (reuse sq)
3768- if (sgitg == 0 ) {
3769- for (short i = 0 ; i < DV8; ++i) {
3770- simdgroup_store (lo[i], so + i*8 , DV, 0 , false );
3771- }
3768+ threadgroup_barrier (mem_flags::mem_threadgroup);
37723769 }
37733770
3774- threadgroup_barrier (mem_flags::mem_threadgroup);
3775-
3776- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *Q*DK);
3771+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *(nsg-1 )*SH + 2 *Q*DK);
37773772
37783773 // final rescale with 1/S and store to global memory
37793774 for (short j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
@@ -3796,17 +3791,17 @@ kernel void kernel_flash_attn_ext(
37963791 half, half4x4, simdgroup_half8x8, \
37973792 float , simdgroup_float8x8, \
37983793 float , simdgroup_float8x8, \
3799- float , float4 , simdgroup_float8x8
3800- // half , half4 , simdgroup_half8x8
3794+ half , half4 , simdgroup_half8x8
3795+ // float , float4 , simdgroup_float8x8
38013796
38023797#define FA_TYPES_BF \
38033798 bfloat, bfloat4, simdgroup_bfloat8x8, \
38043799 bfloat, bfloat4x4, simdgroup_bfloat8x8, \
38053800 bfloat, bfloat4x4, simdgroup_bfloat8x8, \
38063801 float , simdgroup_float8x8, \
38073802 float , simdgroup_float8x8, \
3808- float , float4 , simdgroup_float8x8
3809- // half , half4 , simdgroup_half8x8
3803+ half , half4 , simdgroup_half8x8
3804+ // float , float4 , simdgroup_float8x8
38103805
38113806typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
38123807
0 commit comments