3838
3939 bool support_simdgroup_reduction;
4040 bool support_simdgroup_mm;
41+ bool support_bfloat;
4142
4243 char name[128 ];
4344} g_ggml_ctx_dev_main = {
4445 /* .mtl_device =*/ nil ,
4546 /* .mtl_device_ref_count =*/ 0 ,
4647 /* .support_simdgroup_reduction =*/ false ,
4748 /* .support_simdgroup_mm =*/ false ,
49+ /* .support_bfloat =*/ false ,
4850 /* .name =*/ " " ,
4951};
5052
6062
6163 ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily: MTLGPUFamilyApple7];
6264
65+ ctx->support_bfloat = [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
66+ ctx->support_bfloat |= [ctx->mtl_device supportsFamily: MTLGPUFamilyApple6];
67+
6368 strncpy (ctx->name , [[ctx->mtl_device name ] UTF8String ], sizeof (ctx->name ) - 1 );
6469 }
6570
@@ -541,9 +546,10 @@ @implementation GGMLMetalClass
541546 }
542547 }
543548
544- GGML_LOG_INFO (" %s : simdgroup reduction support = %s \n " , __func__, ctx_dev->support_simdgroup_reduction ? " true" : " false" );
545- GGML_LOG_INFO (" %s : simdgroup matrix mul. support = %s \n " , __func__, ctx_dev->support_simdgroup_mm ? " true" : " false" );
546- GGML_LOG_INFO (" %s : hasUnifiedMemory = %s \n " , __func__, ctx_dev->mtl_device .hasUnifiedMemory ? " true" : " false" );
549+ GGML_LOG_INFO (" %s : simdgroup reduction = %s \n " , __func__, ctx_dev->support_simdgroup_reduction ? " true" : " false" );
550+ GGML_LOG_INFO (" %s : simdgroup matrix mul. = %s \n " , __func__, ctx_dev->support_simdgroup_mm ? " true" : " false" );
551+ GGML_LOG_INFO (" %s : bfloat = %s \n " , __func__, ctx_dev->support_bfloat ? " true" : " false" );
552+ GGML_LOG_INFO (" %s : hasUnifiedMemory = %s \n " , __func__, ctx_dev->mtl_device .hasUnifiedMemory ? " true" : " false" );
547553
548554 ctx->capture_next_compute = false ;
549555 ctx->capture_started = false ;
@@ -591,6 +597,7 @@ @implementation GGMLMetalClass
591597
592598 const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm ;
593599 const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction ;
600+ const bool support_bfloat = ctx_dev->support_bfloat ;
594601
595602 // simd_sum and simd_max requires MTLGPUFamilyApple7
596603
@@ -626,7 +633,7 @@ @implementation GGMLMetalClass
626633 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
627634 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
628635 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
629- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
636+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, support_bfloat );
630637 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
631638 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
632639 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -653,10 +660,10 @@ @implementation GGMLMetalClass
653660 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
654661 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
655662 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
656- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
657- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
658- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
659- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
663+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction && support_bfloat );
664+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction && support_bfloat );
665+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction && support_bfloat );
666+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction && support_bfloat );
660667 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
661668 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
662669 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
@@ -685,7 +692,7 @@ @implementation GGMLMetalClass
685692 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
686693 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
687694 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
688- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction);
695+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction && support_bfloat );
689696 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
690697 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
691698 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
@@ -707,7 +714,7 @@ @implementation GGMLMetalClass
707714 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
708715 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
709716 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
710- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
717+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm && support_bfloat );
711718 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
712719 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
713720 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -729,7 +736,7 @@ @implementation GGMLMetalClass
729736 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
730737 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
731738 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
732- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm);
739+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm && support_bfloat );
733740 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
734741 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
735742 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
@@ -814,11 +821,11 @@ @implementation GGMLMetalClass
814821 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
815822 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
816823 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
817- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true );
824+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, support_bfloat );
818825 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
819826 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
820- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true );
821- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, true );
827+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, support_bfloat );
828+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, support_bfloat );
822829 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
823830 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
824831 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
@@ -910,6 +917,15 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
910917static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
911918 const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm ;
912919 const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction ;
920+ const bool support_bfloat = ctx_dev->support_bfloat ;
921+
922+ if (!support_bfloat) {
923+ for (size_t i = 0 , n = 3 ; i < n; ++i) {
924+ if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
925+ return false ;
926+ }
927+ }
928+ }
913929
914930 switch (op->op ) {
915931 case GGML_OP_UNARY:
0 commit comments