@@ -199,6 +199,22 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
199199 GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200200 GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201201 GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
202+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
203+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
204+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
205+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
206+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
207+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
208+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
209+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
210+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
211+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
212+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
213+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
214+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
215+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
216+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
217+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
202218 GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
203219 GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
204220 GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -747,6 +763,22 @@ @implementation GGMLMetalClass
747763 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
748764 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
749765 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
766+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
767+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
768+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
769+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
770+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
771+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
772+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
773+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
774+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
775+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
776+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
777+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
778+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
779+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
780+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
781+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
750782 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
751783 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
752784 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1978,17 +2010,28 @@ static void ggml_metal_encode_node(
19782010 // to the matrix-vector kernel
19792011 int ne11_mm_min = 4 ;
19802012
1981- if ((src0t == GGML_TYPE_F16 || // TODO: helper function
1982- src0t == GGML_TYPE_Q4_0 ||
1983- src0t == GGML_TYPE_Q4_1 ||
1984- src0t == GGML_TYPE_Q5_0 ||
1985- src0t == GGML_TYPE_Q5_1 ||
1986- src0t == GGML_TYPE_Q8_0
1987- ) &&
1988- src1t == GGML_TYPE_F32 &&
1989- (ne00%256 == 0 ) && // TODO: this can be relaxed to 128 for nxpsg == 8
1990- (ne11 >= 2 && ne11 <= 8 )) {
1991-
2013+ if (src1t == GGML_TYPE_F32 && (ne00%256 == 0 ) &&
2014+ (
2015+ (
2016+ (
2017+ src0t == GGML_TYPE_F16 || // TODO: helper function
2018+ src0t == GGML_TYPE_Q4_0 ||
2019+ src0t == GGML_TYPE_Q4_1 ||
2020+ src0t == GGML_TYPE_Q5_0 ||
2021+ src0t == GGML_TYPE_Q5_1 ||
2022+ src0t == GGML_TYPE_Q8_0 ||
2023+ src0t == GGML_TYPE_IQ4_NL ||
2024+ false ) && (ne11 >= 2 && ne11 <= 8 )
2025+ ) ||
2026+ (
2027+ (
2028+ src0t == GGML_TYPE_Q4_K ||
2029+ src0t == GGML_TYPE_Q5_K ||
2030+ src0t == GGML_TYPE_Q6_K ||
2031+ false ) && (ne11 >= 4 && ne11 <= 12 )
2032+ )
2033+ )
2034+ ) {
19922035 // TODO: determine the optimal parameters based on grid utilization
19932036 const int nsg = 2 ; // TODO: or 4?
19942037 const int nxpsg = ne11 < 3 ? 16 : 8 ;
@@ -2010,9 +2053,6 @@ static void ggml_metal_encode_node(
20102053 r1ptg = 5 ; break ;
20112054 };
20122055
2013- assert (nxpsg >= 8 );
2014- assert (nxpsg%8 == 0 );
2015-
20162056 id <MTLComputePipelineState > pipeline = nil ;
20172057
20182058 switch (src0->type ) {
@@ -2064,6 +2104,38 @@ static void ggml_metal_encode_node(
20642104 case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline ; break ;
20652105 default : GGML_ABORT (" not implemented" );
20662106 } break ;
2107+ case GGML_TYPE_Q4_K:
2108+ switch (r1ptg) {
2109+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline ; break ;
2110+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline ; break ;
2111+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline ; break ;
2112+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline ; break ;
2113+ default : GGML_ABORT (" not implemented" );
2114+ } break ;
2115+ case GGML_TYPE_Q5_K:
2116+ switch (r1ptg) {
2117+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline ; break ;
2118+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline ; break ;
2119+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline ; break ;
2120+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline ; break ;
2121+ default : GGML_ABORT (" not implemented" );
2122+ } break ;
2123+ case GGML_TYPE_Q6_K:
2124+ switch (r1ptg) {
2125+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline ; break ;
2126+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline ; break ;
2127+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline ; break ;
2128+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline ; break ;
2129+ default : GGML_ABORT (" not implemented" );
2130+ } break ;
2131+ case GGML_TYPE_IQ4_NL:
2132+ switch (r1ptg) {
2133+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline ; break ;
2134+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline ; break ;
2135+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline ; break ;
2136+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline ; break ;
2137+ default : GGML_ABORT (" not implemented" );
2138+ } break ;
20672139 default : GGML_ABORT (" not implemented" );
20682140 }
20692141
0 commit comments