@@ -269,6 +269,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ @implementation GGMLMetalClass
         struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_" #name]; \
         kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_" #name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         [metal_function release]; \
         if (error) { \
             GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
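The new `GGML_LOG_INFO` line reports each pipeline's name, address, and dispatch limits as it is created. Below is a standalone C sketch of the resulting log format; the kernel name is taken from the diff above, while the function name, pointer, and the two thread counts are made-up placeholders (the real values come from the `MTLComputePipelineState` at load time):

```c
// Illustrative only: reproduces the new log format string with placeholder
// values. None of the numbers below come from a real device.
#include <stdio.h>

int main(void) {
    void * pipeline = (void *) 0x12345678;  // hypothetical pipeline pointer

    printf("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n",
           "ggml_metal_init",                  // stands in for __func__ at the call site
           "kernel_flash_attn_ext_bf16_h128",  // "kernel_" #name from the macro
           pipeline,
           1024,                               // maxTotalThreadsPerThreadgroup (assumed)
           32);                                // threadExecutionWidth (assumed)
    return 0;
}
```

A `threadExecutionWidth` of 32 is typical for Apple GPUs, which is also why the flash-attention dispatch code further down asserts `ncpsg % 32 == 0`.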
@@ -777,6 +788,12 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -808,12 +825,14 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                                        }
                                }
                        } break;
+                case GGML_TYPE_BF16:
+                    {
+                        switch (ne00) {
+                            case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                            case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                            case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                            case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                            case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                        }
+                    } break;
                 case GGML_TYPE_Q4_0:
                     {
                         switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                    {
                        switch (src1->type) {
                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                           case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                            case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                            case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                            case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                    {
                        switch (src1->type) {
                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                           case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                            case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                            case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                            case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
            [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
            [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
            [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-           [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-           [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-           [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-           [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-           [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-           [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-           [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-           [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-           [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-           [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-           [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-           [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+           [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+           [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+           [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+           [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
+           [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
+           [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+           [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+           [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+           [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
            if (!use_vec_kernel) {
                // half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
                GGML_ASSERT(nqptg % 8  == 0);
                GGML_ASSERT(ncpsg % 32 == 0);
 
+               // 2*(2*ncpsg + nqptg)*(nsg)
+               // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+               //
                // 16*32*(nsg)
                // the shared memory needed for the simdgroups to load the KV cache
                // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
                //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                int64_t nsgmax = 2;
 
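To make the updated budget concrete, here is a small standalone C sketch that evaluates the new half8x8-kernel `FATTN_SMEM` formula. The head size `ne00 = 128` is an example, `nqptg = 8` and `ncpsg = 32` are assumed values consistent with the asserts above, and `GGML_PAD` is redefined locally with round-up-to-multiple semantics:

```c
// Hedged sketch: evaluates the updated FATTN_SMEM formula from the diff above.
// ne00, nqptg, ncpsg are illustrative assumptions, not values from a live graph.
#include <stdint.h>
#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // round x up to a multiple of n

int main(void) {
    const int64_t ne00  = 128; // head size (example)
    const int64_t nqptg = 8;   // queries per threadgroup (assumed; satisfies % 8  == 0)
    const int64_t ncpsg = 32;  // cache values per simdgroup (assumed; satisfies % 32 == 0)

    for (int64_t nsg = 1; nsg <= 4; nsg *= 2) {
        // nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg): Q rows plus soft_max/mask/diag scratch
        // 16*32*nsg: staging area for the dequantized KV data
        // *(sizeof(float)/2): everything is stored as half (2 bytes per element)
        const int64_t smem = GGML_PAD(
            (nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(int64_t)(sizeof(float)/2), 16);
        printf("nsg = %lld -> FATTN_SMEM = %lld bytes\n", (long long) nsg, (long long) smem);
    }
    return 0;
}
```

With these assumptions, `nsg = 2` comes to 8704 bytes, well under the 32 KiB threadgroup-memory limit typical of Apple GPUs; the extra `2*ncpsg` term relative to the old formula pays for keeping the soft_max values and the mask resident in shared memory at the same time.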
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
 
                // ne00 + 2*ncpsg*(nsg)
                // for each query, we load it as f16 in shared memory (ne00)
-               // and store the attention scores (nqptg x ncpsg) as f32
+               // and store the soft_max values and the mask
                //
-               // 2*ne00*(nsg)
-               // each simdgroup has a full f32 head vector in shared mem to accumulate results
+               // ne00*(nsg)
+               // each simdgroup has a full f16 head vector in shared mem to accumulate results
                //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
                int64_t nsgmax = 2;
 
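The vector-kernel budget shrinks correspondingly: switching the per-simdgroup accumulator from f32 to f16 drops that term from `2*ne00` to `ne00` half-sized units. A sketch under the same caveats (`nqptg = 1` and `ncpsg = 32` are assumed values for the vec path, `ne00 = 128` is an example head size):

```c
// Hedged sketch: compares the old and new vec-kernel FATTN_SMEM formulas.
// All parameter values are illustrative assumptions.
#include <stdint.h>
#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // round x up to a multiple of n

int main(void) {
    const int64_t ne00  = 128; // head size (example)
    const int64_t nqptg = 1;   // assumed: the vec kernel handles one query per threadgroup
    const int64_t ncpsg = 32;  // assumed cache values per simdgroup
    const int64_t nsg   = 2;   // example simdgroup count

    // old: f32 accumulator -> 2*ne00 half-sized units per simdgroup
    const int64_t smem_old = GGML_PAD(
        (nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*(int64_t)(sizeof(float)/2), 16);
    // new: f16 accumulator -> ne00 half-sized units per simdgroup
    const int64_t smem_new = GGML_PAD(
        (nqptg*(ne00 + 2*ncpsg*nsg) +   ne00*nsg)*(int64_t)(sizeof(float)/2), 16);

    printf("old = %lld bytes, new = %lld bytes\n",
           (long long) smem_old, (long long) smem_new);
    // with these assumptions: old = 1536 bytes, new = 1024 bytes
    return 0;
}
```

So for a 128-wide head with two simdgroups, the per-threadgroup shared-memory footprint of the vec kernel drops by a third under these assumed parameters.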