3939 bool has_simdgroup_reduction;
4040 bool has_simdgroup_mm;
4141 bool has_bfloat;
42+ bool use_bfloat;
4243
4344 char name[128 ];
4445} g_ggml_ctx_dev_main = {
4748 /* .has_simdgroup_reduction =*/ false ,
4849 /* .has_simdgroup_mm =*/ false ,
4950 /* .has_bfloat =*/ false ,
51+ /* .use_bfloat =*/ false ,
5052 /* .name =*/ " " ,
5153};
5254
6567 ctx->has_bfloat = [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
6668 ctx->has_bfloat |= [ctx->mtl_device supportsFamily: MTLGPUFamilyApple6];
6769
70+ #if defined(GGML_METAL_USE_BF16)
71+ ctx->use_bfloat = ctx->has_bfloat ;
72+ #else
73+ ctx->use_bfloat = false ;
74+ #endif
75+
6876 strncpy (ctx->name , [[ctx->mtl_device name ] UTF8String ], sizeof (ctx->name ) - 1 );
6977 }
7078
@@ -504,6 +512,10 @@ @implementation GGMLMetalClass
504512 // dictionary of preprocessor macros
505513 NSMutableDictionary * prep = [NSMutableDictionary dictionary ];
506514
515+ if (ctx_dev->use_bfloat ) {
516+ [prep setObject: @" 1" forKey: @" GGML_METAL_USE_BF16" ];
517+ }
518+
507519 MTLCompileOptions * options = [MTLCompileOptions new ];
508520 options.preprocessorMacros = prep;
509521
@@ -556,7 +568,8 @@ @implementation GGMLMetalClass
556568
557569 GGML_LOG_INFO (" %s : simdgroup reduction = %s \n " , __func__, ctx_dev->has_simdgroup_reduction ? " true" : " false" );
558570 GGML_LOG_INFO (" %s : simdgroup matrix mul. = %s \n " , __func__, ctx_dev->has_simdgroup_mm ? " true" : " false" );
559- GGML_LOG_INFO (" %s : bfloat = %s \n " , __func__, ctx_dev->has_bfloat ? " true" : " false" );
571+ GGML_LOG_INFO (" %s : has bfloat = %s \n " , __func__, ctx_dev->has_bfloat ? " true" : " false" );
572+ GGML_LOG_INFO (" %s : use bfloat = %s \n " , __func__, ctx_dev->use_bfloat ? " true" : " false" );
560573 GGML_LOG_INFO (" %s : hasUnifiedMemory = %s \n " , __func__, ctx_dev->mtl_device .hasUnifiedMemory ? " true" : " false" );
561574
562575 ctx->capture_next_compute = false ;
@@ -608,7 +621,7 @@ @implementation GGMLMetalClass
608621
609622 const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm ;
610623 const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction ;
611- const bool has_bfloat = ctx_dev->has_bfloat ;
624+ const bool use_bfloat = ctx_dev->use_bfloat ;
612625
613626 // simd_sum and simd_max requires MTLGPUFamilyApple7
614627
@@ -644,7 +657,7 @@ @implementation GGMLMetalClass
644657 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
645658 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
646659 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
647- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat );
660+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat );
648661 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
649662 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
650663 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -671,10 +684,10 @@ @implementation GGMLMetalClass
671684 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
672685 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
673686 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
674- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat );
675- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat );
676- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat );
677- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat );
687+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat );
688+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat );
689+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat );
690+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat );
678691 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
679692 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
680693 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
@@ -703,7 +716,7 @@ @implementation GGMLMetalClass
703716 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
704717 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
705718 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
706- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat );
719+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat );
707720 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
708721 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
709722 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
@@ -725,7 +738,7 @@ @implementation GGMLMetalClass
725738 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
726739 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
727740 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
728- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat );
741+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat );
729742 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
730743 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
731744 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
@@ -747,7 +760,7 @@ @implementation GGMLMetalClass
747760 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
748761 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
749762 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
750- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat );
763+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat );
751764 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
752765 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
753766 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
@@ -840,11 +853,11 @@ @implementation GGMLMetalClass
840853 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
841854 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
842855 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
843- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat );
856+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat );
844857 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
845858 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
846- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat );
847- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat );
859+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat );
860+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat );
848861 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
849862 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
850863 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
@@ -936,9 +949,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
936949static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
937950 const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm ;
938951 const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction ;
939- const bool has_bfloat = ctx_dev->has_bfloat ;
952+ const bool use_bfloat = ctx_dev->use_bfloat ;
940953
941- if (!has_bfloat ) {
954+ if (!use_bfloat ) {
942955 for (size_t i = 0 , n = 3 ; i < n; ++i) {
943956 if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
944957 return false ;
0 commit comments