Skip to content

Commit b74aabf

Browse files
committed
metal : opt-in compile flag for BF16
ggml-ci
1 parent 841f27a commit b74aabf

File tree

7 files changed

+66
-31
lines changed

7 files changed

+66
-31
lines changed

.github/workflows/build.yml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ jobs:
5555
sysctl -a
5656
mkdir build
5757
cd build
58-
cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF ..
58+
cmake .. \
59+
-DLLAMA_FATAL_WARNINGS=ON \
60+
-DLLAMA_CURL=ON \
61+
-DGGML_METAL_USE_BF16=ON \
62+
-DGGML_METAL_EMBED_LIBRARY=ON \
63+
-DGGML_RPC=ON \
64+
-DBUILD_SHARED_LIBS=OFF
5965
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
6066
6167
- name: Test
@@ -113,7 +119,12 @@ jobs:
113119
sysctl -a
114120
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
115121
# https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
116-
cmake -B build -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF
122+
cmake -B build \
123+
-DLLAMA_FATAL_WARNINGS=ON \
124+
-DLLAMA_CURL=ON \
125+
-DGGML_METAL=OFF \
126+
-DGGML_RPC=ON \
127+
-DBUILD_SHARED_LIBS=OFF
117128
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
118129
119130
- name: Test
@@ -569,6 +580,7 @@ jobs:
569580
mkdir build
570581
cd build
571582
cmake -G Xcode .. \
583+
-DGGML_METAL_USE_BF16 \
572584
-DGGML_METAL_EMBED_LIBRARY=ON \
573585
-DLLAMA_BUILD_EXAMPLES=OFF \
574586
-DLLAMA_BUILD_TESTS=OFF \
@@ -599,6 +611,7 @@ jobs:
599611
mkdir build
600612
cd build
601613
cmake -G Xcode .. \
614+
-DGGML_METAL_USE_BF16 \
602615
-DGGML_METAL_EMBED_LIBRARY=ON \
603616
-DLLAMA_BUILD_EXAMPLES=OFF \
604617
-DLLAMA_BUILD_TESTS=OFF \

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,10 @@ ifdef GGML_METAL
878878
MK_CPPFLAGS += -DGGML_USE_METAL
879879
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
880880
OBJ_GGML += ggml/src/ggml-metal.o
881+
882+
ifdef GGML_METAL_USE_BF16
883+
MK_CPPFLAGS += -DGGML_METAL_USE_BF16
884+
endif # GGML_METAL_USE_BF16
881885
ifdef GGML_METAL_NDEBUG
882886
MK_CPPFLAGS += -DGGML_METAL_NDEBUG
883887
endif

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ cSettings.append(
4848
let package = Package(
4949
name: "llama",
5050
platforms: [
51-
.macOS(.v12),
51+
.macOS(.v14),
5252
.iOS(.v14),
5353
.watchOS(.v4),
5454
.tvOS(.v14)

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
153153
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
154154
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
155155
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
156+
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
156157
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
157158
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
158159
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})

ggml/src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ if (GGML_METAL)
5858
add_compile_definitions(GGML_METAL_NDEBUG)
5959
endif()
6060

61+
if (GGML_METAL_USE_BF16)
62+
add_compile_definitions(GGML_METAL_USE_BF16)
63+
endif()
64+
6165
# copy ggml-common.h and ggml-metal.metal to bin directory
6266
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
6367
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

ggml/src/ggml-metal.m

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
bool has_simdgroup_reduction;
4040
bool has_simdgroup_mm;
4141
bool has_bfloat;
42+
bool use_bfloat;
4243

4344
char name[128];
4445
} g_ggml_ctx_dev_main = {
@@ -47,6 +48,7 @@
4748
/*.has_simdgroup_reduction =*/ false,
4849
/*.has_simdgroup_mm =*/ false,
4950
/*.has_bfloat =*/ false,
51+
/*.use_bfloat =*/ false,
5052
/*.name =*/ "",
5153
};
5254

@@ -65,6 +67,12 @@
6567
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
6668
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
6769

70+
#if defined(GGML_METAL_USE_BF16)
71+
ctx->use_bfloat = ctx->has_bfloat;
72+
#else
73+
ctx->use_bfloat = false;
74+
#endif
75+
6876
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
6977
}
7078

@@ -504,6 +512,10 @@ @implementation GGMLMetalClass
504512
// dictionary of preprocessor macros
505513
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
506514

515+
if (ctx_dev->use_bfloat) {
516+
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
517+
}
518+
507519
MTLCompileOptions * options = [MTLCompileOptions new];
508520
options.preprocessorMacros = prep;
509521

@@ -556,7 +568,8 @@ @implementation GGMLMetalClass
556568

557569
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
558570
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
559-
GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
571+
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
572+
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
560573
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
561574

562575
ctx->capture_next_compute = false;
@@ -608,7 +621,7 @@ @implementation GGMLMetalClass
608621

609622
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
610623
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
611-
const bool has_bfloat = ctx_dev->has_bfloat;
624+
const bool use_bfloat = ctx_dev->use_bfloat;
612625

613626
// simd_sum and simd_max requires MTLGPUFamilyApple7
614627

@@ -644,7 +657,7 @@ @implementation GGMLMetalClass
644657
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
645658
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
646659
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
647-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat);
660+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
648661
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
649662
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
650663
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -671,10 +684,10 @@ @implementation GGMLMetalClass
671684
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
672685
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
673686
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
674-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat);
675-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat);
676-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat);
677-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat);
687+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
688+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
689+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
690+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
678691
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
679692
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
680693
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
@@ -703,7 +716,7 @@ @implementation GGMLMetalClass
703716
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
704717
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
705718
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
706-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat);
719+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
707720
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
708721
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
709722
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
@@ -725,7 +738,7 @@ @implementation GGMLMetalClass
725738
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
726739
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
727740
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
728-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat);
741+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
729742
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
730743
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
731744
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
@@ -747,7 +760,7 @@ @implementation GGMLMetalClass
747760
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
748761
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
749762
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
750-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat);
763+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
751764
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
752765
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
753766
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
@@ -840,11 +853,11 @@ @implementation GGMLMetalClass
840853
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
841854
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
842855
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
843-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat);
856+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
844857
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
845858
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
846-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat);
847-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat);
859+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
860+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
848861
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
849862
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
850863
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -936,9 +949,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
936949
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
937950
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
938951
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
939-
const bool has_bfloat = ctx_dev->has_bfloat;
952+
const bool use_bfloat = ctx_dev->use_bfloat;
940953

941-
if (!has_bfloat) {
954+
if (!use_bfloat) {
942955
for (size_t i = 0, n = 3; i < n; ++i) {
943956
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
944957
return false;

0 commit comments

Comments
 (0)