Skip to content

Commit 3ee077a

Browse files
committed
metal : check for bfloat support on the Metal device
ggml-ci
1 parent c915d0a commit 3ee077a

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

ggml/src/ggml-metal.m

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838

3939
bool support_simdgroup_reduction;
4040
bool support_simdgroup_mm;
41+
bool support_bfloat;
4142

4243
char name[128];
4344
} g_ggml_ctx_dev_main = {
4445
/*.mtl_device =*/ nil,
4546
/*.mtl_device_ref_count =*/ 0,
4647
/*.support_simdgroup_reduction =*/ false,
4748
/*.support_simdgroup_mm =*/ false,
49+
/*.support_bfloat =*/ false,
4850
/*.name =*/ "",
4951
};
5052

@@ -60,6 +62,9 @@
6062

6163
ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
6264

65+
ctx->support_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
66+
ctx->support_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
67+
6368
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
6469
}
6570

@@ -541,9 +546,10 @@ @implementation GGMLMetalClass
541546
}
542547
}
543548

544-
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
545-
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
546-
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
549+
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
550+
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
551+
GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->support_bfloat ? "true" : "false");
552+
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
547553

548554
ctx->capture_next_compute = false;
549555
ctx->capture_started = false;
@@ -591,6 +597,7 @@ @implementation GGMLMetalClass
591597

592598
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
593599
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
600+
const bool support_bfloat = ctx_dev->support_bfloat;
594601

595602
// simd_sum and simd_max requires MTLGPUFamilyApple7
596603

@@ -626,7 +633,7 @@ @implementation GGMLMetalClass
626633
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
627634
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
628635
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
629-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
636+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, support_bfloat);
630637
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
631638
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
632639
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -653,10 +660,10 @@ @implementation GGMLMetalClass
653660
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
654661
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
655662
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
656-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
657-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
658-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
659-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
663+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction && support_bfloat);
664+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction && support_bfloat);
665+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction && support_bfloat);
666+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction && support_bfloat);
660667
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
661668
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
662669
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
@@ -685,7 +692,7 @@ @implementation GGMLMetalClass
685692
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
686693
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
687694
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
688-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction);
695+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction && support_bfloat);
689696
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
690697
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
691698
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
@@ -707,7 +714,7 @@ @implementation GGMLMetalClass
707714
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
708715
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
709716
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
710-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
717+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm && support_bfloat);
711718
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
712719
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
713720
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -729,7 +736,7 @@ @implementation GGMLMetalClass
729736
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
730737
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
731738
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
732-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm);
739+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm && support_bfloat);
733740
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
734741
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
735742
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
@@ -814,11 +821,11 @@ @implementation GGMLMetalClass
814821
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
815822
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
816823
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
817-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
824+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, support_bfloat);
818825
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
819826
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
820-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
821-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, true);
827+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, support_bfloat);
828+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, support_bfloat);
822829
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
823830
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
824831
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -910,6 +917,15 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
910917
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
911918
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
912919
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
920+
const bool support_bfloat = ctx_dev->support_bfloat;
921+
922+
if (!support_bfloat) {
923+
for (size_t i = 0, n = 3; i < n; ++i) {
924+
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
925+
return false;
926+
}
927+
}
928+
}
913929

914930
switch (op->op) {
915931
case GGML_OP_UNARY:

0 commit comments

Comments
 (0)