Commit c915d0a

metal : add mul_mat_id BF16 support
ggml-ci
1 parent: 6109cf1 · commit: c915d0a
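MUL_MAT_ID is the indirectly indexed (expert-routed) matrix multiplication used by mixture-of-experts models; this commit registers BF16 variants of the Metal matrix-matrix and matrix-vector kernels for it, plus a BF16→BF16 copy kernel. As a rough sketch of the op the new kernels serve, using the public ggml C API with illustrative names and shapes that are not taken from this commit:

// Sketch only (hypothetical helper, not from this commit): build the MUL_MAT_ID node
// that the new BF16 kernels serve -- BF16 expert weights, F32 activations, I32 routing ids.
#include "ggml.h"

static struct ggml_tensor * build_bf16_mul_mat_id(struct ggml_context * ctx) {
    const int n_embd = 256, n_ff = 512, n_expert = 4, n_expert_used = 2, n_tokens = 32;

    // one [n_embd x n_ff] weight matrix per expert, stored as BF16
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_BF16, n_embd, n_ff, n_expert);
    // F32 activations, one column per (used expert, token) pair
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  n_embd, n_expert_used, n_tokens);
    // routing indices: which expert each slot of each token is sent to
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,  n_expert_used, n_tokens);

    return ggml_mul_mat_id(ctx, as, b, ids);
}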

3 files changed: +39 additions, −26 deletions


ggml/src/ggml-metal.m

Lines changed: 29 additions & 19 deletions
@@ -147,10 +147,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,

@@ -175,10 +175,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
-  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,

@@ -222,6 +223,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,

@@ -310,6 +312,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,

@@ -654,10 +657,10 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);

@@ -678,10 +681,11 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
-  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
   //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
   //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
+  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);

@@ -725,6 +729,7 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);

@@ -813,6 +818,7 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);

@@ -902,17 +908,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
 }
 
 static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
-    for (size_t i = 0, n = 3; i < n; ++i) {
-        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
-            op->op != GGML_OP_GET_ROWS &&
-            op->op != GGML_OP_MUL_MAT &&
-            op->op != GGML_OP_VIEW &&
-            op->op != GGML_OP_CPY) {
-            GGML_LOG_ERROR("unsupported BF16 op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
-            GGML_ASSERT(false);
-        }
-    }
-
     const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
     const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
 

@@ -1002,10 +997,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 return false;
         }
     case GGML_TYPE_F16:
-    case GGML_TYPE_BF16:
         switch (op->type) {
             case GGML_TYPE_F32:
             case GGML_TYPE_F16:
+                return true;
+            default:
+                return false;
+        }
+    case GGML_TYPE_BF16:
+        switch (op->type) {
+            case GGML_TYPE_F32:
             case GGML_TYPE_BF16:
                 return true;
             default:

@@ -2203,12 +2204,12 @@ static void ggml_metal_encode_node(
     if ([device supportsFamily:MTLGPUFamilyApple7] &&
         ne00 % 32 == 0 && ne00 >= 64 &&
         dst_rows > dst_rows_min) {
-
         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
         switch (src0->type) {
-            case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
-            case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+            case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+            case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+            case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
             default: break;
         }
 

@@ -2217,6 +2218,7 @@ static void ggml_metal_encode_node(
         switch (src0->type) {
             case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
             case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
             case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
             case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
             case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;

@@ -2286,6 +2288,13 @@ static void ggml_metal_encode_node(
             nth1 = 1;
             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
         } break;
+        case GGML_TYPE_BF16:
+        {
+            GGML_ASSERT(src1t == GGML_TYPE_F32);
+            nth0 = 32;
+            nth1 = 1;
+            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
+        } break;
         case GGML_TYPE_Q4_0:
         {
             nth0 = 8;

@@ -3305,6 +3314,7 @@ static void ggml_metal_encode_node(
     {
         switch (dstt) {
             case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
+            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
             default: GGML_ASSERT(false && "not implemented");
         };
     } break;
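Taken together, the ggml-metal.m changes give BF16 MUL_MAT_ID two code paths, mirroring what F16 already had: the simdgroup matrix-matrix kernel when the device is Apple7-class, ne00 is a multiple of 32 and at least 64, and the destination has more rows than dst_rows_min; otherwise the matrix-vector kernel, which requires src1 to be F32. A paraphrase of that dispatch, as a sketch (the helper below is hypothetical, not code from the patch):

#include <stdint.h>

// Hypothetical helper mirroring the dispatch condition shown in the diff above,
// for a BF16 src0 in MUL_MAT_ID on the Metal backend.
static const char * pick_mul_mat_id_kernel_bf16(int supports_apple7,
                                                int64_t ne00, int64_t dst_rows, int64_t dst_rows_min) {
    if (supports_apple7 && ne00 % 32 == 0 && ne00 >= 64 && dst_rows > dst_rows_min) {
        // matrix-matrix path; BF16 additionally asserts nb01 % 8 == 0 (aligned row stride)
        return "kernel_mul_mm_id_bf16_f32";
    }
    // matrix-vector fallback: src1 must be F32, dispatched with nth0 = 32, nth1 = 1
    return "kernel_mul_mv_id_bf16_f32";
}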

ggml/src/ggml-metal.metal

Lines changed: 9 additions & 6 deletions
@@ -3576,12 +3576,13 @@ kernel void kernel_cpy(
 
 typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
 
-template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
-template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
-template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
-template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
-template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
-template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
+template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 
 kernel void kernel_cpy_f32_q8_0(
     device const float * src0,

@@ -6547,6 +6548,7 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
 
 template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
 template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;

@@ -6770,6 +6772,7 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
 
 template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
 template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -3599,7 +3599,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (int n_mats : {4}) {
         for (int n_used : {2}) {
             for (bool b : {false}) {
-                for (int n : {1}) {
+                for (int n : {1, 32}) {
                     int m = 512;
                     int k = 256;
                     test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
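The test change pairs with the new kernels: with n = 1 the destination is small enough that mul_mat_id stays on the matrix-vector (mul_mv_id) path, while n = 32 makes dst_rows large enough to also take the new mul_mm_id matrix-matrix path, for every type_a including BF16. To run just these cases on the Metal backend, an invocation along the lines of `./bin/test-backend-ops test -o MUL_MAT_ID -b Metal` should work; the flag names and backend string here are from memory, so check the program's usage output.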
