@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
33283328    constexpr  short  NW  = N_SIMDWIDTH;
33293329    constexpr  short  SH  = (2 *C + Q); //  shared memory per simdgroup (s_t == float)
33303330
3331-     const  short  TS = nsg*SH;   //  shared memory size per query in (s_t == float)
3332-     const  short  T  = DK + 2 *TS; //  shared memory size per query in (half)
3331+     const  short  TS = nsg*SH;       //  shared memory size per query in (s_t == float)
3332+     const  short  T  = 2 * DK + 2 *TS; //  shared memory size per query in (half)
33333333
3334-     threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +              0 *DK); //  holds the query data
3335-     threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +              0 *DK); //  same as above but in q4_t
3336-     threadgroup o_t   * so  = (threadgroup o_t   *) (shmem_f16 +              0 *DK); //  reuse query data for accumulation
3337-     threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 +              0 *DK); //  same as above but in o4_t
3338-     threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + 2 *sgitg*SH + Q*DK); //  scratch buffer for attention, mask and diagonal matrix
3334+     threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                 0 *DK); //  holds the query data
3335+     threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                 0 *DK); //  same as above but in q4_t
3336+     threadgroup o_t   * so  = (threadgroup o_t   *) (shmem_f16 +                 0 *DK); //  reuse query data for accumulation
3337+     threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 +                 0 *DK); //  same as above but in o4_t
3338+     threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + 2 *sgitg*SH + 2 * Q*DK); //  scratch buffer for attention, mask and diagonal matrix
33393339
33403340    threadgroup k_t     * sk    = (threadgroup k_t     *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); //  scratch buffer to load K in shared memory
33413341    threadgroup k4x4_t  * sk4x4 = (threadgroup k4x4_t  *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); //  same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
33543354            if  (iq1 + j < args.ne01 ) {
33553355                sq4[j*DK4 + i] = (q4_t ) q4[i];
33563356            } else  {
3357-                 sq4[j*DK4 + i] = ( q4_t )  0 . 0f ;
3357+                 sq4[j*DK4 + i] = 0 ;
33583358            }
33593359        }
33603360    }
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
36343634
36353635    //  reduce the warps sequentially
36363636    for  (ushort sg = 1 ; sg < nsg; ++sg) {
3637-         float  S = { 0 .0f  };
3638-         float  M = { -__FLT_MAX__/2  };
3639- 
36403637        threadgroup_barrier (mem_flags::mem_threadgroup);
36413638
36423639        //  each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
36573654                const  float  M0 = ss[j*TS +         1 ];
36583655                const  float  M1 = ss[j*TS + sg*SH + 1 ];
36593656
3660-                 M = max (M0, M1);
3657+                 const   float   M = max (M0, M1);
36613658
36623659                const  float  ms0 = exp (M0 - M);
36633660                const  float  ms1 = exp (M1 - M);
36643661
3665-                 S = S0*ms0 + S1*ms1;
3662+                 const   float   S = S0*ms0 + S1*ms1;
36663663
36673664                if  (tiisg == 0 ) {
36683665                    ss[j*TS + 0 ] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
37013698        }
37023699    }
37033700
3704-     device float4 * dst4 = (device float4 *) dst;
3701+     threadgroup_barrier (mem_flags::mem_threadgroup);
3702+ 
3703+     threadgroup s_t  * sf = (threadgroup s_t  *) (shmem_f16 + 2 *Q*DK);
37053704
37063705    //  final rescale with 1/S and store to global memory
3707-     if  (sgitg == 0 ) {
3708-         for  (short  j = 0 ; j < Q && iq1 + j < args.ne01 ; ++j) {
3709-             const  float  S = ss[j*TS + 0 ];
3706+     for  (short  j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
3707+         const  float  S = 1 .0f /sf[j*TS + 0 ];
37103708
3711-             for  (short  i = tiisg; i < DV4; i += NW) {
3712-                 dst4[((uint64_t )iq3*args.ne2 *args.ne1  + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3713-             }
3709+         device float4 * dst4 = (device float4 *) dst + ((uint64_t )iq3*args.ne2 *args.ne1  + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4;
3710+ 
3711+         for  (short  i = tiisg; i < DV4; i += NW) {
3712+             dst4[i] = (float4) so4[j*DV4 + i]*S;
37143713        }
37153714    }
37163715}
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
37193718//        template to be able to explore different combinations
37203719// 
37213720#define  FA_TYPES  \
3722-     half,  half4,   simdgroup_half8x8,  \
3723-     half,  half4x4, simdgroup_half8x8,  \
3724-     half,  half4x4, simdgroup_half8x8,  \
3725-     float ,          simdgroup_float8x8, \
3726-     float ,          simdgroup_float8x8, \
3727-     half,  half4,   simdgroup_half8x8
3721+     float ,  float4,    simdgroup_float8x8, \
3722+     half,   half4x4,   simdgroup_half8x8,  \
3723+     half,   half4x4,   simdgroup_half8x8,  \
3724+     float ,             simdgroup_float8x8, \
3725+     float ,             simdgroup_float8x8, \
3726+     float ,  float4,    simdgroup_float8x8
3727+     // half,   half4,     simdgroup_half8x8
3728+ 
3729+ #define  FA_TYPES_BF  \
3730+     bfloat, bfloat4,   simdgroup_bfloat8x8, \
3731+     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3732+     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3733+     float ,             simdgroup_float8x8,  \
3734+     float ,             simdgroup_float8x8,  \
3735+     float ,  float4,    simdgroup_float8x8
3736+     // half,   half4,     simdgroup_half8x8
37283737
37293738typedef  decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
37303739
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_at
37393748template  [[host_name(" kernel_flash_attn_ext_f16_hk576_hv512" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, half4x4,    1 , dequantize_f16,  half4x4,    1 , dequantize_f16,  576 , 512 >;
37403749
37413750#if  defined(GGML_METAL_USE_BF16)
3742- template  [[host_name(" kernel_flash_attn_ext_bf16_h64" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 64 ,  64 >;
3743- template  [[host_name(" kernel_flash_attn_ext_bf16_h80" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 80 ,  80 >;
3744- template  [[host_name(" kernel_flash_attn_ext_bf16_h96" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 96 ,  96 >;
3745- template  [[host_name(" kernel_flash_attn_ext_bf16_h112" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 112 , 112 >;
3746- template  [[host_name(" kernel_flash_attn_ext_bf16_h128" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 128 , 128 >;
3747- template  [[host_name(" kernel_flash_attn_ext_bf16_h192" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 192 , 192 >;
3748- template  [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 192 , 128 >;
3749- template  [[host_name(" kernel_flash_attn_ext_bf16_h256" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 256 , 256 >;
3750- template  [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 576 , 512 >;
3751+ template  [[host_name(" kernel_flash_attn_ext_bf16_h64" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 64 ,  64 >;
3752+ template  [[host_name(" kernel_flash_attn_ext_bf16_h80" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 80 ,  80 >;
3753+ template  [[host_name(" kernel_flash_attn_ext_bf16_h96" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 96 ,  96 >;
3754+ template  [[host_name(" kernel_flash_attn_ext_bf16_h112" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 112 , 112 >;
3755+ template  [[host_name(" kernel_flash_attn_ext_bf16_h128" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 128 , 128 >;
3756+ template  [[host_name(" kernel_flash_attn_ext_bf16_h192" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 192 , 192 >;
3757+ template  [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 192 , 128 >;
3758+ template  [[host_name(" kernel_flash_attn_ext_bf16_h256" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 256 , 256 >;
3759+ template  [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 576 , 512 >;
37513760#endif 
37523761
37533762template  [[host_name(" kernel_flash_attn_ext_q4_0_h64" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 ,  64 >;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_at
38013810template  [[host_name(" kernel_flash_attn_ext_q8_0_hk576_hv512" flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
38023811
38033812#undef  FA_TYPES
3813+ #undef  FA_TYPES_BF
38043814
38053815template <
38063816    typename  q4_t ,  //  query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
38473857
38483858    const  short  T = DK + nsg*SH; //  shared memory size per query in (half)
38493859
3850-   // threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
3851-     threadgroup q4_t   * sq4 = (threadgroup q4_t   *) (shmem_f16 +                  0 *DK); //  same as above but in q4_t
3852-     threadgroup s_t    * ss  = (threadgroup s_t    *) (shmem_f16 + sgitg*SH       + Q*DK); //  scratch buffer for attention
3853-     threadgroup s4_t   * ss4 = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH       + Q*DK); //  same as above but in s4_t
3854-     threadgroup float  * sm  = (threadgroup float  *) (shmem_f16 + sgitg*SH + 2 *C + Q*DK); //  scratch buffer for mask
3855-     threadgroup o4_t   * sr4 = (threadgroup o4_t   *) (shmem_f16 + sgitg*DV       + Q*T);  //  scratch buffer for the results
3860+   // threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                     0*DK); // holds the query data
3861+     threadgroup q4_t   * sq4 = (threadgroup q4_t   *) (shmem_f16 +                     0 *DK); //  same as above but in q4_t
3862+     threadgroup s_t    * ss  = (threadgroup s_t    *) (shmem_f16 +    sgitg*SH       + Q*DK); //  scratch buffer for attention
3863+     threadgroup s4_t   * ss4 = (threadgroup s4_t   *) (shmem_f16 +    sgitg*SH       + Q*DK); //  same as above but in s4_t
3864+     threadgroup float  * sm  = (threadgroup float  *) (shmem_f16 +    sgitg*SH + 2 *C + Q*DK); //  scratch buffer for mask
3865+     threadgroup o4_t   * sr4 = (threadgroup o4_t   *) (shmem_f16 + 2 * sgitg*DV       + Q*T);  //  scratch buffer for the results
38563866
38573867    //  store the result for all queries in local memory (the O matrix from the paper)
38583868    o4_t  lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
41574167           half4,  \
41584168    float ,         \
41594169    float , float4, \
4160-            half4 
4170+            float4 
41614171
41624172typedef  decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
41634173
0 commit comments