5858 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
5959 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
6060 GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61- GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
6261 GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
62+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
6363 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
6464 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
6565 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
8484 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
8585 GGML_METAL_KERNEL_TYPE_NORM,
8686 GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
8791 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
8892 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
8993 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
132136 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
133137 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
134138 GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
135140 GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
136141 GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
137142 GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
@@ -515,8 +520,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
515520 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
516521 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
517522 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
518- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
519523 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
524+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
520525 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
521526 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
522527 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -541,6 +546,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
541546 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction );
542547 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
543548 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction );
549+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction );
550+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction );
551+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction );
552+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction );
544553 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction );
545554 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction );
546555 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction );
@@ -589,6 +598,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
589598 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction );
590599 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction );
591600 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm );
601+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm );
592602 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm );
593603 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm );
594604 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm );
@@ -739,7 +749,8 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
739749static bool ggml_metal_supports_op (const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
740750 for (size_t i = 0 , n = 3 ; i < n; ++i) {
741751 if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16 &&
742- op->op != GGML_OP_GET_ROWS) {
752+ op->op != GGML_OP_GET_ROWS &&
753+ op->op != GGML_OP_MUL_MAT) {
743754 printf (" op = %s , src[%zu ] = %s \n " , ggml_op_name (op->op ), i, ggml_type_name (op->src [i]->type ));
744755 GGML_ASSERT (false );
745756 }
@@ -1584,15 +1595,17 @@ static enum ggml_status ggml_metal_graph_compute(
15841595 // some Metal matrix data types require aligned pointers
15851596 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
15861597 switch (src0->type ) {
1587- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1588- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1598+ case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1599+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1600+ case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
15891601 default : break ;
15901602 }
15911603
15921604 id <MTLComputePipelineState > pipeline = nil ;
15931605
15941606 switch (src0->type ) {
15951607 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline ; break ;
1608+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline ; break ;
15961609 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline ; break ;
15971610 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline ; break ;
15981611 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline ; break ;
@@ -1669,6 +1682,25 @@ static enum ggml_status ggml_metal_graph_compute(
16691682 nrows = 4 ;
16701683 }
16711684 } break ;
1685+ case GGML_TYPE_BF16:
1686+ {
1687+ nth0 = 32 ;
1688+ nth1 = 1 ;
1689+ if (src1t == GGML_TYPE_F32) {
1690+ if (ne11 * ne12 < 4 ) {
1691+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline ;
1692+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
1693+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline ;
1694+ nrows = ne11;
1695+ } else {
1696+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline ;
1697+ nrows = 4 ;
1698+ }
1699+ } else {
1700+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline ;
1701+ nrows = 4 ;
1702+ }
1703+ } break ;
16721704 case GGML_TYPE_Q4_0:
16731705 {
16741706 nth0 = 8 ;
@@ -2165,8 +2197,8 @@ static enum ggml_status ggml_metal_graph_compute(
21652197
21662198 switch (src0->type ) {
21672199 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline ; break ;
2168- case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
21692200 case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline ; break ;
2201+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
21702202 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline ; break ;
21712203 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline ; break ;
21722204 case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline ; break ;
0 commit comments