Skip to content

Commit ad12269

Browse files
committed
metal : do not build bfloat kernels when not supported
ggml-ci
1 parent a408f51 commit ad12269

File tree

2 files changed: 29 additions and 3 deletions

2 files changed: 29 additions and 3 deletions

ggml/src/ggml-metal.m

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -496,7 +496,11 @@ @implementation GGMLMetalClass
496496
// dictionary of preprocessor macros
497497
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
498498

499-
MTLCompileOptions* options = [MTLCompileOptions new];
499+
if (!ctx_dev->has_bfloat) {
500+
[prep setObject:@"GGML_METAL_NO_BFLOAT" forKey:@"GGML_METAL_NO_BFLOAT"];
501+
}
502+
503+
MTLCompileOptions * options = [MTLCompileOptions new];
500504
options.preprocessorMacros = prep;
501505

502506
//[options setFastMathEnabled:false];

ggml/src/ggml-metal.metal

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,14 @@ using namespace metal;
1212

1313
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
1414

15+
#if !defined(GGML_METAL_NO_BFLOAT)
16+
typedef matrix<bfloat, 4, 4> bfloat4x4;
17+
#endif
18+
1519
constexpr constant static float kvalues_iq4nl_f[16] = {
1620
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
1721
};
1822

19-
typedef matrix<bfloat, 4, 4> bfloat4x4;
20-
2123
// NOTE: this is not dequantizing - we are simply fitting the template
2224
template <typename type4x4>
2325
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -29,10 +31,12 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
2931
reg = (type4x4)(*src);
3032
}
3133

34+
#if !defined(GGML_METAL_NO_BFLOAT)
3235
template <typename type4x4>
3336
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
3437
reg = (type4x4)(*src);
3538
}
39+
#endif
3640

3741
template <typename type4x4>
3842
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
@@ -2048,8 +2052,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
20482052
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
20492053
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
20502054
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
2055+
#if !defined(GGML_METAL_NO_BFLOAT)
20512056
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
20522057
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
2058+
#endif
20532059

20542060
template<typename T, typename T4>
20552061
kernel void kernel_mul_mv_1row(
@@ -2119,7 +2125,9 @@ kernel void kernel_mul_mv_1row(
21192125
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
21202126

21212127
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
2128+
#if !defined(GGML_METAL_NO_BFLOAT)
21222129
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
2130+
#endif
21232131

21242132
// Assumes row size (ne00) is a multiple of 4
21252133
template<typename T, typename T4>
@@ -2179,7 +2187,9 @@ kernel void kernel_mul_mv_l4(
21792187
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
21802188

21812189
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
2190+
#if !defined(GGML_METAL_NO_BFLOAT)
21822191
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
2192+
#endif
21832193

21842194
static float rope_yarn_ramp(const float low, const float high, const int i0) {
21852195
const float y = (i0 / 2 - low) / max(0.001f, high - low);
@@ -3578,11 +3588,15 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
35783588

35793589
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
35803590
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
3591+
#if !defined(GGML_METAL_NO_BFLOAT)
35813592
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
3593+
#endif
35823594
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
35833595
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
3596+
#if !defined(GGML_METAL_NO_BFLOAT)
35843597
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
35853598
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
3599+
#endif
35863600

35873601
kernel void kernel_cpy_f32_q8_0(
35883602
device const float * src0,
@@ -6487,7 +6501,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
64876501

64886502
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
64896503
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
6504+
#if !defined(GGML_METAL_NO_BFLOAT)
64906505
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
6506+
#endif
64916507

64926508
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
64936509

@@ -6519,7 +6535,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
65196535

65206536
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
65216537
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6538+
#if !defined(GGML_METAL_NO_BFLOAT)
65226539
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6540+
#endif
65236541
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
65246542
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
65256543
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
@@ -6548,7 +6566,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
65486566

65496567
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
65506568
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
6569+
#if !defined(GGML_METAL_NO_BFLOAT)
65516570
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
6571+
#endif
65526572
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
65536573
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
65546574
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
@@ -6772,7 +6792,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
67726792

67736793
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
67746794
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
6795+
#if !defined(GGML_METAL_NO_BFLOAT)
67756796
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
6797+
#endif
67766798
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
67776799
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
67786800
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;

0 commit comments

Comments (0)