Skip to content

Commit c2d312b

Browse files
committed
metal : opt-in compile flag for BF16
ggml-ci
1 parent d05b312 commit c2d312b

File tree

7 files changed

+66
-31
lines changed

7 files changed

+66
-31
lines changed

.github/workflows/build.yml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ jobs:
5555
sysctl -a
5656
mkdir build
5757
cd build
58-
cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF ..
58+
cmake .. \
59+
-DLLAMA_FATAL_WARNINGS=ON \
60+
-DLLAMA_CURL=ON \
61+
-DGGML_METAL_USE_BF16=ON \
62+
-DGGML_METAL_EMBED_LIBRARY=ON \
63+
-DGGML_RPC=ON \
64+
-DBUILD_SHARED_LIBS=OFF
5965
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
6066
6167
- name: Test
@@ -113,7 +119,12 @@ jobs:
113119
sysctl -a
114120
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
115121
# https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
116-
cmake -B build -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF
122+
cmake -B build \
123+
-DLLAMA_FATAL_WARNINGS=ON \
124+
-DLLAMA_CURL=ON \
125+
-DGGML_METAL=OFF \
126+
-DGGML_RPC=ON \
127+
-DBUILD_SHARED_LIBS=OFF
117128
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
118129
119130
- name: Test
@@ -569,6 +580,7 @@ jobs:
569580
mkdir build
570581
cd build
571582
cmake -G Xcode .. \
583+
-DGGML_METAL_USE_BF16=ON \
572584
-DGGML_METAL_EMBED_LIBRARY=ON \
573585
-DLLAMA_BUILD_EXAMPLES=OFF \
574586
-DLLAMA_BUILD_TESTS=OFF \
@@ -599,6 +611,7 @@ jobs:
599611
mkdir build
600612
cd build
601613
cmake -G Xcode .. \
614+
-DGGML_METAL_USE_BF16=ON \
602615
-DGGML_METAL_EMBED_LIBRARY=ON \
603616
-DLLAMA_BUILD_EXAMPLES=OFF \
604617
-DLLAMA_BUILD_TESTS=OFF \

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,10 @@ ifdef GGML_METAL
878878
MK_CPPFLAGS += -DGGML_USE_METAL
879879
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
880880
OBJ_GGML += ggml/src/ggml-metal.o
881+
882+
ifdef GGML_METAL_USE_BF16
883+
MK_CPPFLAGS += -DGGML_METAL_USE_BF16
884+
endif # GGML_METAL_USE_BF16
881885
ifdef GGML_METAL_NDEBUG
882886
MK_CPPFLAGS += -DGGML_METAL_NDEBUG
883887
endif

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ cSettings.append(
4848
let package = Package(
4949
name: "llama",
5050
platforms: [
51-
.macOS(.v12),
51+
.macOS(.v14),
5252
.iOS(.v14),
5353
.watchOS(.v4),
5454
.tvOS(.v14)

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
153153
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
154154
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
155155
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
156+
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
156157
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
157158
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
158159
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})

ggml/src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ if (GGML_METAL)
5858
add_compile_definitions(GGML_METAL_NDEBUG)
5959
endif()
6060

61+
if (GGML_METAL_USE_BF16)
62+
add_compile_definitions(GGML_METAL_USE_BF16)
63+
endif()
64+
6165
# copy ggml-common.h and ggml-metal.metal to bin directory
6266
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
6367
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

ggml/src/ggml-metal.m

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
bool has_simdgroup_reduction;
4040
bool has_simdgroup_mm;
4141
bool has_bfloat;
42+
bool use_bfloat;
4243

4344
char name[128];
4445
} g_ggml_ctx_dev_main = {
@@ -47,6 +48,7 @@
4748
/*.has_simdgroup_reduction =*/ false,
4849
/*.has_simdgroup_mm =*/ false,
4950
/*.has_bfloat =*/ false,
51+
/*.use_bfloat =*/ false,
5052
/*.name =*/ "",
5153
};
5254

@@ -65,6 +67,12 @@
6567
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
6668
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
6769

70+
#if defined(GGML_METAL_USE_BF16)
71+
ctx->use_bfloat = ctx->has_bfloat;
72+
#else
73+
ctx->use_bfloat = false;
74+
#endif
75+
6876
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
6977
}
7078

@@ -496,6 +504,10 @@ @implementation GGMLMetalClass
496504
// dictionary of preprocessor macros
497505
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
498506

507+
if (ctx_dev->use_bfloat) {
508+
[prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
509+
}
510+
499511
MTLCompileOptions * options = [MTLCompileOptions new];
500512
options.preprocessorMacros = prep;
501513

@@ -548,7 +560,8 @@ @implementation GGMLMetalClass
548560

549561
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
550562
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
551-
GGML_LOG_INFO("%s: bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
563+
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
564+
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
552565
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
553566

554567
ctx->capture_next_compute = false;
@@ -597,7 +610,7 @@ @implementation GGMLMetalClass
597610

598611
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
599612
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
600-
const bool has_bfloat = ctx_dev->has_bfloat;
613+
const bool use_bfloat = ctx_dev->use_bfloat;
601614

602615
// simd_sum and simd_max requires MTLGPUFamilyApple7
603616

@@ -633,7 +646,7 @@ @implementation GGMLMetalClass
633646
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
634647
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
635648
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
636-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, has_bfloat);
649+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
637650
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
638651
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
639652
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -660,10 +673,10 @@ @implementation GGMLMetalClass
660673
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
661674
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
662675
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
663-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && has_bfloat);
664-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && has_bfloat);
665-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && has_bfloat);
666-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && has_bfloat);
676+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
677+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
678+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
679+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
667680
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
668681
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
669682
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
@@ -692,7 +705,7 @@ @implementation GGMLMetalClass
692705
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
693706
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
694707
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
695-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && has_bfloat);
708+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
696709
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
697710
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
698711
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
@@ -714,7 +727,7 @@ @implementation GGMLMetalClass
714727
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
715728
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
716729
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
717-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && has_bfloat);
730+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
718731
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
719732
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
720733
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
@@ -736,7 +749,7 @@ @implementation GGMLMetalClass
736749
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
737750
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
738751
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
739-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && has_bfloat);
752+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
740753
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
741754
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
742755
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
@@ -821,11 +834,11 @@ @implementation GGMLMetalClass
821834
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
822835
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
823836
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
824-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, has_bfloat);
837+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
825838
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
826839
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
827-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, has_bfloat);
828-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, has_bfloat);
840+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
841+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
829842
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
830843
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
831844
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -917,9 +930,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
917930
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
918931
const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
919932
const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
920-
const bool has_bfloat = ctx_dev->has_bfloat;
933+
const bool use_bfloat = ctx_dev->use_bfloat;
921934

922-
if (!has_bfloat) {
935+
if (!use_bfloat) {
923936
for (size_t i = 0, n = 3; i < n; ++i) {
924937
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
925938
return false;

0 commit comments

Comments
 (0)