3939 bool has_simdgroup_reduction;
4040 bool has_simdgroup_mm;
4141 bool has_bfloat;
42+ bool use_bfloat;
4243
4344 char name[128 ];
4445} g_ggml_ctx_dev_main = {
4748 /* .has_simdgroup_reduction =*/ false ,
4849 /* .has_simdgroup_mm =*/ false ,
4950 /* .has_bfloat =*/ false ,
51+ /* .use_bfloat =*/ false ,
5052 /* .name =*/ " " ,
5153};
5254
6567 ctx->has_bfloat = [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
6668 ctx->has_bfloat |= [ctx->mtl_device supportsFamily: MTLGPUFamilyApple6];
6769
70+ #if defined(GGML_METAL_USE_BF16)
71+ ctx->use_bfloat = ctx->has_bfloat ;
72+ #else
73+ ctx->use_bfloat = false ;
74+ #endif
75+
6876 strncpy (ctx->name , [[ctx->mtl_device name ] UTF8String ], sizeof (ctx->name ) - 1 );
6977 }
7078
@@ -496,6 +504,10 @@ @implementation GGMLMetalClass
496504 // dictionary of preprocessor macros
497505 NSMutableDictionary * prep = [NSMutableDictionary dictionary ];
498506
507+ if (ctx_dev->use_bfloat ) {
508+ [prep setObject: @" 1" forKey: @" GGML_METAL_USE_BF16" ];
509+ }
510+
499511 MTLCompileOptions * options = [MTLCompileOptions new ];
500512 options.preprocessorMacros = prep;
501513
@@ -548,7 +560,8 @@ @implementation GGMLMetalClass
548560
549561 GGML_LOG_INFO (" %s : simdgroup reduction = %s \n " , __func__, ctx_dev->has_simdgroup_reduction ? " true" : " false" );
550562 GGML_LOG_INFO (" %s : simdgroup matrix mul. = %s \n " , __func__, ctx_dev->has_simdgroup_mm ? " true" : " false" );
551- GGML_LOG_INFO (" %s : bfloat = %s \n " , __func__, ctx_dev->has_bfloat ? " true" : " false" );
563+ GGML_LOG_INFO (" %s : has bfloat = %s \n " , __func__, ctx_dev->has_bfloat ? " true" : " false" );
564+ GGML_LOG_INFO (" %s : use bfloat = %s \n " , __func__, ctx_dev->use_bfloat ? " true" : " false" );
552565 GGML_LOG_INFO (" %s : hasUnifiedMemory = %s \n " , __func__, ctx_dev->mtl_device .hasUnifiedMemory ? " true" : " false" );
553566
554567 ctx->capture_next_compute = false ;
@@ -597,7 +610,7 @@ @implementation GGMLMetalClass
597610
598611 const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm ;
599612 const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction ;
600- const bool has_bfloat = ctx_dev->has_bfloat ;
613+ const bool use_bfloat = ctx_dev->use_bfloat ;
601614
602615 // simd_sum and simd_max requires MTLGPUFamilyApple7
603616
@@ -633,7 +646,7 @@ @implementation GGMLMetalClass
633646 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
634647 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
635648 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
636- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat );
649+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat );
637650 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
638651 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
639652 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -660,10 +673,10 @@ @implementation GGMLMetalClass
660673 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
661674 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
662675 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
663- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat );
664- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat );
665- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat );
666- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat );
676+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat );
677+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat );
678+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat );
679+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat );
667680 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
668681 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
669682 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
@@ -692,7 +705,7 @@ @implementation GGMLMetalClass
692705 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
693706 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
694707 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
695- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat );
708+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat );
696709 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
697710 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
698711 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
@@ -714,7 +727,7 @@ @implementation GGMLMetalClass
714727 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
715728 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
716729 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
717- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat );
730+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat );
718731 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
719732 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
720733 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
@@ -736,7 +749,7 @@ @implementation GGMLMetalClass
736749 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
737750 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
738751 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
739- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat );
752+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat );
740753 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
741754 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
742755 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
@@ -821,11 +834,11 @@ @implementation GGMLMetalClass
821834 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
822835 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
823836 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
824- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat );
837+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat );
825838 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
826839 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
827- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat );
828- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat );
840+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat );
841+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat );
829842 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
830843 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
831844 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
@@ -917,9 +930,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
917930static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
918931 const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm ;
919932 const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction ;
920- const bool has_bfloat = ctx_dev->has_bfloat ;
933+ const bool use_bfloat = ctx_dev->use_bfloat ;
921934
922- if (!has_bfloat ) {
935+ if (!use_bfloat ) {
923936 for (size_t i = 0 , n = 3 ; i < n; ++i) {
924937 if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
925938 return false ;
0 commit comments