@@ -12,12 +12,14 @@ using namespace metal;
1212
1313#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
1414
// Added by this change: the bfloat 4x4 matrix alias is declared only when
// GGML_METAL_NO_BFLOAT is not defined, so a build that defines the macro never
// names the bfloat type here. The alias was previously declared unconditionally
// after kvalues_iq4nl_f.
15+ #if !defined(GGML_METAL_NO_BFLOAT)
// 4x4 matrix with bfloat elements; it is the source type read by dequantize_bf16.
16+ typedef matrix<bfloat, 4 , 4 > bfloat4x4;
17+ #endif
18+
// Compile-time table of 16 float values placed in the `constant` address space.
// The entries are strictly increasing, from -127 to 113, and are not evenly spaced.
// NOTE(review): the name suggests this is the value lookup table for the IQ4_NL
// quantization format - confirm against the code that indexes it.
1519constexpr constant static float kvalues_iq4nl_f[16 ] = {
1620 -127 .f , -104 .f , -83 .f , -65 .f , -49 .f , -35 .f , -22 .f , -10 .f , 1 .f , 13 .f , 25 .f , 38 .f , 53 .f , 69 .f , 89 .f , 113 .f
1721};
1822
19- typedef matrix<bfloat, 4 , 4 > bfloat4x4;
20-
2123// NOTE: this is not dequantizing - we are simply fitting the template
2224template <typename type4x4>
2325void dequantize_f32 (device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -29,10 +31,12 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
2931 reg = (type4x4)(*src);
3032}
3133
// Added by this change: dequantize_bf16 is compiled only when GGML_METAL_NO_BFLOAT
// is not defined, because its signature uses bfloat4x4.
34+ #if !defined(GGML_METAL_NO_BFLOAT)
// Reads one bfloat4x4 from `src`, converts it to the caller-selected matrix type
// `type4x4`, and stores the result in `reg`.
//   src - pointer (device address space) to the source 4x4 bfloat matrix
//   il  - not used by this function; presumably kept so the signature matches the
//         other dequantize_* functions passed as template arguments - TODO confirm
//   reg - output matrix in the thread address space, overwritten with the result
3235template <typename type4x4>
3336void dequantize_bf16 (device const bfloat4x4 * src, short il, thread type4x4 & reg) {
// No per-element work is done: the whole matrix is converted in a single cast.
3437 reg = (type4x4)(*src);
3538}
39+ #endif
3640
3741template <typename type4x4>
3842void dequantize_q4_0 (device const block_q4_0 *xb, short il, thread type4x4 & reg) {
@@ -2048,8 +2052,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
// Explicit instantiations of the kernel_mul_mv template. Each one is given its own
// name through the [[host_name(...)]] attribute and shares the mul_mv_t signature.
// NOTE(review): judging by the host names, the first type pair describes the first
// operand and the second pair the second operand - confirm against the template.
20482052template [[host_name(" kernel_mul_mv_f32_f32" )]] kernel mul_mv_t kernel_mul_mv<float , float4, float , float4>;
20492053template [[host_name(" kernel_mul_mv_f16_f32" )]] kernel mul_mv_t kernel_mul_mv<half, half4, float , float4>;
20502054template [[host_name(" kernel_mul_mv_f16_f16" )]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
// Added by this change: the two bfloat instantiations are emitted only when
// GGML_METAL_NO_BFLOAT is not defined.
2055+ #if !defined(GGML_METAL_NO_BFLOAT)
20512056template [[host_name(" kernel_mul_mv_bf16_f32" )]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float , float4>;
20522057template [[host_name(" kernel_mul_mv_bf16_bf16" )]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
2058+ #endif
20532059
20542060template <typename T, typename T4>
20552061kernel void kernel_mul_mv_1row (
@@ -2119,7 +2125,9 @@ kernel void kernel_mul_mv_1row(
// Function type shared by every kernel_mul_mv_1row instantiation; it is taken from
// the half/half4 specialization.
21192125typedef decltype (kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
21202126
// Explicit instantiations of kernel_mul_mv_1row, each named through [[host_name(...)]].
21212127template [[host_name(" kernel_mul_mv_f16_f32_1row" )]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
// Added by this change: the bfloat instantiation is emitted only when
// GGML_METAL_NO_BFLOAT is not defined.
2128+ #if !defined(GGML_METAL_NO_BFLOAT)
21222129template [[host_name(" kernel_mul_mv_bf16_f32_1row" )]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
2130+ #endif
21232131
21242132// Assumes row size (ne00) is a multiple of 4
21252133template <typename T, typename T4>
@@ -2179,7 +2187,9 @@ kernel void kernel_mul_mv_l4(
// Function type shared by every kernel_mul_mv_l4 instantiation; it is taken from
// the half/half4 specialization.
21792187typedef decltype (kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
21802188
// Explicit instantiations of kernel_mul_mv_l4, each named through [[host_name(...)]].
// The comment ahead of the template states that the row size (ne00) is assumed to
// be a multiple of 4.
21812189template [[host_name(" kernel_mul_mv_f16_f32_l4" )]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
// Added by this change: the bfloat instantiation is emitted only when
// GGML_METAL_NO_BFLOAT is not defined.
2190+ #if !defined(GGML_METAL_NO_BFLOAT)
21822191template [[host_name(" kernel_mul_mv_bf16_f32_l4" )]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
2192+ #endif
21832193
21842194static float rope_yarn_ramp (const float low, const float high, const int i0) {
21852195 const float y = (i0 / 2 - low) / max (0 .001f , high - low);
@@ -3578,11 +3588,15 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
35783588
// Explicit instantiations of the kernel_cpy template, one per type pair, each named
// through [[host_name(...)]] and sharing the kernel_cpy_t signature.
// NOTE(review): the host names suggest the first template argument is the source
// element type and the second the destination type - confirm against kernel_cpy.
35793589template [[host_name(" kernel_cpy_f32_f32" )]] kernel kernel_cpy_t kernel_cpy<float , float >;
35803590template [[host_name(" kernel_cpy_f32_f16" )]] kernel kernel_cpy_t kernel_cpy<float , half>;
// Added by this change: every instantiation that names bfloat is emitted only when
// GGML_METAL_NO_BFLOAT is not defined. Two separate guards are needed because the
// bfloat variants are interleaved with the half variants.
3591+ #if !defined(GGML_METAL_NO_BFLOAT)
35813592template [[host_name(" kernel_cpy_f32_bf16" )]] kernel kernel_cpy_t kernel_cpy<float , bfloat>;
3593+ #endif
35823594template [[host_name(" kernel_cpy_f16_f32" )]] kernel kernel_cpy_t kernel_cpy<half, float >;
35833595template [[host_name(" kernel_cpy_f16_f16" )]] kernel kernel_cpy_t kernel_cpy<half, half>;
3596+ #if !defined(GGML_METAL_NO_BFLOAT)
35843597template [[host_name(" kernel_cpy_bf16_f32" )]] kernel kernel_cpy_t kernel_cpy<bfloat, float >;
35853598template [[host_name(" kernel_cpy_bf16_bf16" )]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
3599+ #endif
35863600
35873601kernel void kernel_cpy_f32_q8_0 (
35883602 device const float * src0,
@@ -6487,7 +6501,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
64876501
// Explicit instantiations of kernel_get_rows_f for each unquantized element type,
// each named through [[host_name(...)]] and sharing the get_rows_f_t signature.
64886502template [[host_name(" kernel_get_rows_f32" )]] kernel get_rows_f_t kernel_get_rows_f<float >;
64896503template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_f_t kernel_get_rows_f<half>;
// Added by this change: the bfloat instantiation is emitted only when
// GGML_METAL_NO_BFLOAT is not defined.
6504+ #if !defined(GGML_METAL_NO_BFLOAT)
64906505template [[host_name(" kernel_get_rows_bf16" )]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
6506+ #endif
64916507
64926508typedef decltype (kernel_get_rows_q<block_q4_0, 2 , dequantize_q4_0>) get_rows_q_t;
64936509
@@ -6519,7 +6535,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
65196535
// Explicit instantiations of kernel_mul_mm for the unquantized source types, each
// named through [[host_name(...)]] and sharing the mat_mm_t signature. The f32 and
// f16 variants both pass half, half4x4 and simdgroup_half8x8 as the first three
// template arguments; the bf16 variant passes the bfloat equivalents.
65206536template [[host_name(" kernel_mul_mm_f32_f32" )]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32>;
65216537template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16>;
// Added by this change: the bfloat instantiation is emitted only when
// GGML_METAL_NO_BFLOAT is not defined. The guard is required because bfloat4x4 and
// dequantize_bf16 are themselves declared under the same condition.
6538+ #if !defined(GGML_METAL_NO_BFLOAT)
65226539template [[host_name(" kernel_mul_mm_bf16_f32" )]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16>;
6540+ #endif
65236541template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0>;
65246542template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1>;
65256543template [[host_name(" kernel_mul_mm_q5_0_f32" )]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0>;
@@ -6548,7 +6566,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
65486566
// Explicit instantiations of kernel_mul_mm_id for the unquantized source types,
// each named through [[host_name(...)]] and sharing the mat_mm_id_t signature.
// Each passes a source block type, the constant 1, and the matching dequantize_*
// function.
65496567template [[host_name(" kernel_mul_mm_id_f32_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1 , dequantize_f32>;
65506568template [[host_name(" kernel_mul_mm_id_f16_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1 , dequantize_f16>;
// Added by this change: the bfloat instantiation is emitted only when
// GGML_METAL_NO_BFLOAT is not defined. The guard is required because bfloat4x4 and
// dequantize_bf16 are themselves declared under the same condition.
6569+ #if !defined(GGML_METAL_NO_BFLOAT)
65516570template [[host_name(" kernel_mul_mm_id_bf16_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1 , dequantize_bf16>;
6571+ #endif
65526572template [[host_name(" kernel_mul_mm_id_q4_0_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2 , dequantize_q4_0>;
65536573template [[host_name(" kernel_mul_mm_id_q4_1_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2 , dequantize_q4_1>;
65546574template [[host_name(" kernel_mul_mm_id_q5_0_f32" )]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2 , dequantize_q5_0>;
@@ -6772,7 +6792,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
67726792
// Explicit instantiations of kernel_mul_mv_id for the unquantized source types,
// each named through [[host_name(...)]] and sharing the kernel_mul_mv_id_t
// signature. Each wraps a kernel_mul_mv_impl specialization in mmv_fn.
67736793template [[host_name(" kernel_mul_mv_id_f32_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float , float4, float , float4>>>;
67746794template [[host_name(" kernel_mul_mv_id_f16_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float , float4>>>;
// Added by this change: the bfloat instantiation is emitted only when
// GGML_METAL_NO_BFLOAT is not defined.
6795+ #if !defined(GGML_METAL_NO_BFLOAT)
67756796template [[host_name(" kernel_mul_mv_id_bf16_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float , float4>>>;
6797+ #endif
67766798template [[host_name(" kernel_mul_mv_id_q8_0_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
67776799template [[host_name(" kernel_mul_mv_id_q4_0_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
67786800template [[host_name(" kernel_mul_mv_id_q4_1_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
0 commit comments