@@ -147,10 +147,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147147 GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
148148 GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
149149 GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
150- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
151150 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
152151 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
153152 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
153+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
154154 GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
155155 GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
156156 GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
@@ -175,10 +175,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
175175 GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
176176 GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
177177 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
178- // GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
179178 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
180179 // GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
181180 // GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
181+ // GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
182+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
182183 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
183184 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
184185 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
@@ -222,6 +223,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
222223 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
223224 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
224225 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
226+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
225227 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
226228 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
227229 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
@@ -310,6 +312,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
310312 GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
311313 GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
312314 GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
315+ GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
313316 GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
314317 GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
315318 GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -654,10 +657,10 @@ @implementation GGMLMetalClass
654657 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
655658 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
656659 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
657- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
658660 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
659661 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
660662 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
663+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
661664 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
662665 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
663666 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
@@ -678,10 +681,11 @@ @implementation GGMLMetalClass
678681 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
679682 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
680683 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
681- // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
682684 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
683685 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
684686 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
687+ // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
688+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction);
685689 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
686690 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
687691 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
@@ -725,6 +729,7 @@ @implementation GGMLMetalClass
725729 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
726730 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
727731 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
732+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm);
728733 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
729734 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
730735 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
@@ -813,6 +818,7 @@ @implementation GGMLMetalClass
813818 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
814819 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
815820 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true );
821+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, true );
816822 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
817823 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
818824 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
@@ -902,17 +908,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
902908}
903909
904910static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
905- for (size_t i = 0 , n = 3 ; i < n; ++i) {
906- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16 &&
907- op->op != GGML_OP_GET_ROWS &&
908- op->op != GGML_OP_MUL_MAT &&
909- op->op != GGML_OP_VIEW &&
910- op->op != GGML_OP_CPY) {
911- GGML_LOG_ERROR (" unsupported BF16 op = %s , src[%zu ] = %s \n " , ggml_op_name (op->op ), i, ggml_type_name (op->src [i]->type ));
912- GGML_ASSERT (false );
913- }
914- }
915-
916911 const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm ;
917912 const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction ;
918913
@@ -1002,10 +997,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1002997 return false ;
1003998 }
1004999 case GGML_TYPE_F16:
1005- case GGML_TYPE_BF16:
10061000 switch (op->type ) {
10071001 case GGML_TYPE_F32:
10081002 case GGML_TYPE_F16:
1003+ return true ;
1004+ default :
1005+ return false ;
1006+ }
1007+ case GGML_TYPE_BF16:
1008+ switch (op->type ) {
1009+ case GGML_TYPE_F32:
10091010 case GGML_TYPE_BF16:
10101011 return true ;
10111012 default :
@@ -2203,12 +2204,12 @@ static void ggml_metal_encode_node(
22032204 if ([device supportsFamily: MTLGPUFamilyApple7] &&
22042205 ne00 % 32 == 0 && ne00 >= 64 &&
22052206 dst_rows > dst_rows_min) {
2206-
22072207 // some Metal matrix data types require aligned pointers
22082208 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
22092209 switch (src0->type ) {
2210- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
2211- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
2210+ case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
2211+ case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
2212+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
22122213 default : break ;
22132214 }
22142215
@@ -2217,6 +2218,7 @@ static void ggml_metal_encode_node(
22172218 switch (src0->type ) {
22182219 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline ; break ;
22192220 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline ; break ;
2221+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline ; break ;
22202222 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline ; break ;
22212223 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline ; break ;
22222224 case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline ; break ;
@@ -2286,6 +2288,13 @@ static void ggml_metal_encode_node(
22862288 nth1 = 1 ;
22872289 pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline ;
22882290 } break ;
2291+ case GGML_TYPE_BF16:
2292+ {
2293+ GGML_ASSERT (src1t == GGML_TYPE_F32);
2294+ nth0 = 32 ;
2295+ nth1 = 1 ;
2296+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline ;
2297+ } break ;
22892298 case GGML_TYPE_Q4_0:
22902299 {
22912300 nth0 = 8 ;
@@ -3305,6 +3314,7 @@ static void ggml_metal_encode_node(
33053314 {
33063315 switch (dstt) {
33073316 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
3317+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline ; break ;
33083318 default : GGML_ASSERT (false && " not implemented" );
33093319 };
33103320 } break ;
0 commit comments