@@ -381,6 +381,7 @@ - (void) dealloc {
381381 // additional, inference-time compiled kernels
382382 NSMutableDictionary * kernels_ext;
383383
384+ bool use_bfloat;
384385 bool use_fusion;
385386 bool use_concurrency;
386387 bool use_graph_optimize;
@@ -487,6 +488,7 @@ @implementation GGMLMetalClass
487488
488489 ctx->d_queue = dispatch_queue_create (" ggml-metal" , DISPATCH_QUEUE_CONCURRENT);
489490
491+ ctx->use_bfloat = ctx->props_dev .has_bfloat ;
490492 ctx->use_fusion = getenv (" GGML_METAL_FUSION_DISABLE" ) == nil ;
491493 ctx->use_concurrency = getenv (" GGML_METAL_CONCURRENCY_DISABLE" ) == nil ;
492494
@@ -508,6 +510,7 @@ @implementation GGMLMetalClass
508510
509511 memset (ctx->fuse_cnt , 0 , sizeof (ctx->fuse_cnt ));
510512
513+ GGML_LOG_INFO (" %s : use bfloat = %s \n " , __func__, ctx->use_bfloat ? " true" : " false" );
511514 GGML_LOG_INFO (" %s : use fusion = %s \n " , __func__, ctx->use_fusion ? " true" : " false" );
512515 GGML_LOG_INFO (" %s : use concurrency = %s \n " , __func__, ctx->use_concurrency ? " true" : " false" );
513516 GGML_LOG_INFO (" %s : use graph optimize = %s \n " , __func__, ctx->use_graph_optimize ? " true" : " false" );
@@ -557,7 +560,7 @@ @implementation GGMLMetalClass
557560
558561 const bool has_simdgroup_mm = ctx->props_dev .has_simdgroup_mm ;
559562 const bool has_simdgroup_reduction = ctx->props_dev .has_simdgroup_reduction ;
560- const bool use_bfloat = ctx->props_dev .use_bfloat ;
563+ const bool has_bfloat = ctx->props_dev .has_bfloat ;
561564
562565 // simd_sum and simd_max requires MTLGPUFamilyApple7
563566
@@ -595,7 +598,7 @@ @implementation GGMLMetalClass
595598 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
596599 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
597600 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
598- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat );
601+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat );
599602 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
600603 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
601604 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -619,7 +622,7 @@ @implementation GGMLMetalClass
619622 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
620623 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true );
621624 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true );
622- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat );
625+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, has_bfloat );
623626 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true );
624627 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true );
625628 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true );
@@ -636,11 +639,11 @@ @implementation GGMLMetalClass
636639 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true );
637640 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
638641 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true );
639- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat );
640- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat );
641- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat );
642- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat );
643- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat );
642+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat );
643+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, has_bfloat );
644+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat );
645+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat );
646+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat );
644647 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
645648 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true );
646649 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
@@ -719,7 +722,7 @@ @implementation GGMLMetalClass
719722 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
720723 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
721724 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
722- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat );
725+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat );
723726 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
724727 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
725728 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
@@ -742,7 +745,7 @@ @implementation GGMLMetalClass
742745 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
743746 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
744747 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
745- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat );
748+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat );
746749 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
747750 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
748751 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
@@ -772,7 +775,7 @@ @implementation GGMLMetalClass
772775 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
773776 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
774777 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
775- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat );
778+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && has_bfloat );
776779 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
777780 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
778781 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
@@ -817,11 +820,11 @@ @implementation GGMLMetalClass
817820 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true );
818821 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
819822 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
820- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat );
823+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat );
821824 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
822825 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
823- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat );
824- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat );
826+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat );
827+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat );
825828 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true );
826829 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true );
827830 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
@@ -1348,9 +1351,9 @@ static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer
13481351static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_props * props_dev, const struct ggml_tensor * op) {
13491352 const bool has_simdgroup_mm = props_dev->has_simdgroup_mm ;
13501353 const bool has_simdgroup_reduction = props_dev->has_simdgroup_reduction ;
1351- const bool use_bfloat = props_dev->use_bfloat ;
1354+ const bool has_bfloat = props_dev->has_bfloat ;
13521355
1353- if (!use_bfloat ) {
1356+ if (!has_bfloat ) {
13541357 if (op->type == GGML_TYPE_BF16) {
13551358 return false ;
13561359 }
@@ -6088,9 +6091,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
60886091
60896092 const int n_nodes_per_cb = ctx->n_nodes_per_cb ;
60906093
6091- id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [cb_idx].obj ;
6092- struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs [cb_idx].mem_ranges ;
6094+ id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [cb_idx].obj ;
60936095
6096+ struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs [cb_idx].mem_ranges ;
60946097 if (mem_ranges) {
60956098 ggml_mem_ranges_reset (mem_ranges);
60966099 }
@@ -6467,9 +6470,6 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
64676470static struct ggml_backend_feature g_ggml_backend_metal_features[] = {
64686471#if defined(GGML_METAL_EMBED_LIBRARY)
64696472 { " EMBED_LIBRARY" , " 1" },
6470- #endif
6471- #if defined(GGML_METAL_USE_BF16)
6472- { " BF16" , " 1" },
64736473#endif
64746474 { nil , nil },
64756475};
0 commit comments