@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
33283328 constexpr short NW = N_SIMDWIDTH;
33293329 constexpr short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
33303330
3331- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332- const short T = DK + 2 *TS; // shared memory size per query in (half)
3331+ const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332+ const short T = 2 * DK + 2 *TS; // shared memory size per query in (half)
33333333
3334- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
3335- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3336- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3337- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
3338- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3334+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
3335+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3336+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3337+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
3338+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + 2 * Q*DK); // scratch buffer for attention, mask and diagonal matrix
33393339
33403340 threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // scratch buffer to load K in shared memory
33413341 threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
33543354 if (iq1 + j < args.ne01 ) {
33553355 sq4[j*DK4 + i] = (q4_t ) q4[i];
33563356 } else {
3357- sq4[j*DK4 + i] = ( q4_t ) 0 . 0f ;
3357+ sq4[j*DK4 + i] = 0 ;
33583358 }
33593359 }
33603360 }
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
36343634
36353635 // reduce the warps sequentially
36363636 for (ushort sg = 1 ; sg < nsg; ++sg) {
3637- float S = { 0 .0f };
3638- float M = { -__FLT_MAX__/2 };
3639-
36403637 threadgroup_barrier (mem_flags::mem_threadgroup);
36413638
36423639 // each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
36573654 const float M0 = ss[j*TS + 1 ];
36583655 const float M1 = ss[j*TS + sg*SH + 1 ];
36593656
3660- M = max (M0, M1);
3657+ const float M = max (M0, M1);
36613658
36623659 const float ms0 = exp (M0 - M);
36633660 const float ms1 = exp (M1 - M);
36643661
3665- S = S0*ms0 + S1*ms1;
3662+ const float S = S0*ms0 + S1*ms1;
36663663
36673664 if (tiisg == 0 ) {
36683665 ss[j*TS + 0 ] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
37013698 }
37023699 }
37033700
3704- device float4 * dst4 = (device float4 *) dst;
3701+ threadgroup_barrier (mem_flags::mem_threadgroup);
3702+
3703+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *Q*DK);
37053704
37063705 // final rescale with 1/S and store to global memory
3707- if (sgitg == 0 ) {
3708- for (short j = 0 ; j < Q && iq1 + j < args.ne01 ; ++j) {
3709- const float S = ss[j*TS + 0 ];
3706+ for (short j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
3707+ const float S = 1 .0f /sf[j*TS + 0 ];
37103708
3711- for (short i = tiisg; i < DV4; i += NW) {
3712- dst4[((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3713- }
3709+ device float4 * dst4 = (device float4 *) dst + ((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4;
3710+
3711+ for (short i = tiisg; i < DV4; i += NW) {
3712+ dst4[i] = (float4) so4[j*DV4 + i]*S;
37143713 }
37153714 }
37163715}
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
37193718// template to be able to explore different combinations
37203719//
37213720#define FA_TYPES \
3722- half, half4, simdgroup_half8x8, \
3723- half, half4x4, simdgroup_half8x8, \
3724- half, half4x4, simdgroup_half8x8, \
3725- float , simdgroup_float8x8, \
3726- float , simdgroup_float8x8, \
3727- half, half4, simdgroup_half8x8
3721+ float , float4, simdgroup_float8x8, \
3722+ half, half4x4, simdgroup_half8x8, \
3723+ half, half4x4, simdgroup_half8x8, \
3724+ float , simdgroup_float8x8, \
3725+ float , simdgroup_float8x8, \
3726+ float , float4, simdgroup_float8x8
3727+ // half, half4, simdgroup_half8x8
3728+
3729+ #define FA_TYPES_BF \
3730+ bfloat, bfloat4, simdgroup_bfloat8x8, \
3731+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3732+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3733+ float , simdgroup_float8x8, \
3734+ float , simdgroup_float8x8, \
3735+ float , float4, simdgroup_float8x8
3736+ // half, half4, simdgroup_half8x8
37283737
37293738typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
37303739
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
37393748template [[host_name(" kernel_flash_attn_ext_f16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
37403749
37413750#if defined(GGML_METAL_USE_BF16)
3742- template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
3743- template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 , 80 >;
3744- template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 , 96 >;
3745- template [[host_name(" kernel_flash_attn_ext_bf16_h112" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 112 , 112 >;
3746- template [[host_name(" kernel_flash_attn_ext_bf16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 , 128 >;
3747- template [[host_name(" kernel_flash_attn_ext_bf16_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
3748- template [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
3749- template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
3750- template [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
3751+ template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
3752+ template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 , 80 >;
3753+ template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 , 96 >;
3754+ template [[host_name(" kernel_flash_attn_ext_bf16_h112" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 112 , 112 >;
3755+ template [[host_name(" kernel_flash_attn_ext_bf16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 , 128 >;
3756+ template [[host_name(" kernel_flash_attn_ext_bf16_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
3757+ template [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
3758+ template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
3759+ template [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
37513760#endif
37523761
37533762template [[host_name(" kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 , 64 >;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
38013810template [[host_name(" kernel_flash_attn_ext_q8_0_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
38023811
38033812#undef FA_TYPES
3813+ #undef FA_TYPES_BF
38043814
38053815template <
38063816 typename q4_t , // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
38473857
38483858 const short T = DK + nsg*SH; // shared memory size per query in (half)
38493859
3850- // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3851- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3852- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3853- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3854- threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 *C + Q*DK); // scratch buffer for mask
3855- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3860+ // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3861+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3862+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3863+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3864+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 *C + Q*DK); // scratch buffer for mask
3865+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2 * sgitg*DV + Q*T); // scratch buffer for the results
38563866
38573867 // store the result for all queries in local memory (the O matrix from the paper)
38583868 o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
41574167 half4, \
41584168 float , \
41594169 float , float4, \
4160- half4
4170+ float4
41614171
41624172typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
41634173
0 commit comments