Skip to content

Commit 677ee9f

Browse files
committed
metal : add rest of types
ggml-ci
1 parent f45c40e commit 677ee9f

File tree

3 files changed

+360
-74
lines changed

3 files changed

+360
-74
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,22 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
199199
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200200
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201201
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
202+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
203+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
204+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
205+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
206+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
207+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
208+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
209+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
210+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
211+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
212+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
213+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
214+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
215+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
216+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
217+
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
202218
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
203219
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
204220
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -747,6 +763,22 @@ @implementation GGMLMetalClass
747763
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
748764
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
749765
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
766+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
767+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
768+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
769+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
770+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
771+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
772+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
773+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
774+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
775+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
776+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
777+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
778+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
779+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
780+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
781+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
750782
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
751783
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
752784
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1978,17 +2010,28 @@ static void ggml_metal_encode_node(
19782010
// to the matrix-vector kernel
19792011
int ne11_mm_min = 4;
19802012

1981-
if ((src0t == GGML_TYPE_F16 || // TODO: helper function
1982-
src0t == GGML_TYPE_Q4_0 ||
1983-
src0t == GGML_TYPE_Q4_1 ||
1984-
src0t == GGML_TYPE_Q5_0 ||
1985-
src0t == GGML_TYPE_Q5_1 ||
1986-
src0t == GGML_TYPE_Q8_0
1987-
) &&
1988-
src1t == GGML_TYPE_F32 &&
1989-
(ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8
1990-
(ne11 >= 2 && ne11 <= 8)) {
1991-
2013+
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
2014+
(
2015+
(
2016+
(
2017+
src0t == GGML_TYPE_F16 || // TODO: helper function
2018+
src0t == GGML_TYPE_Q4_0 ||
2019+
src0t == GGML_TYPE_Q4_1 ||
2020+
src0t == GGML_TYPE_Q5_0 ||
2021+
src0t == GGML_TYPE_Q5_1 ||
2022+
src0t == GGML_TYPE_Q8_0 ||
2023+
src0t == GGML_TYPE_IQ4_NL ||
2024+
false) && (ne11 >= 2 && ne11 <= 8)
2025+
) ||
2026+
(
2027+
(
2028+
src0t == GGML_TYPE_Q4_K ||
2029+
src0t == GGML_TYPE_Q5_K ||
2030+
src0t == GGML_TYPE_Q6_K ||
2031+
false) && (ne11 >= 4 && ne11 <= 12)
2032+
)
2033+
)
2034+
) {
19922035
// TODO: determine the optimal parameters based on grid utilization
19932036
const int nsg = 2; // TODO: or 4?
19942037
const int nxpsg = ne11 < 3 ? 16 : 8;
@@ -2010,9 +2053,6 @@ static void ggml_metal_encode_node(
20102053
r1ptg = 5; break;
20112054
};
20122055

2013-
assert(nxpsg >= 8);
2014-
assert(nxpsg%8 == 0);
2015-
20162056
id<MTLComputePipelineState> pipeline = nil;
20172057

20182058
switch (src0->type) {
@@ -2064,6 +2104,38 @@ static void ggml_metal_encode_node(
20642104
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
20652105
default: GGML_ABORT("not implemented");
20662106
} break;
2107+
case GGML_TYPE_Q4_K:
2108+
switch (r1ptg) {
2109+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
2110+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
2111+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
2112+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
2113+
default: GGML_ABORT("not implemented");
2114+
} break;
2115+
case GGML_TYPE_Q5_K:
2116+
switch (r1ptg) {
2117+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
2118+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
2119+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
2120+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
2121+
default: GGML_ABORT("not implemented");
2122+
} break;
2123+
case GGML_TYPE_Q6_K:
2124+
switch (r1ptg) {
2125+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
2126+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
2127+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
2128+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
2129+
default: GGML_ABORT("not implemented");
2130+
} break;
2131+
case GGML_TYPE_IQ4_NL:
2132+
switch (r1ptg) {
2133+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
2134+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
2135+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
2136+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
2137+
default: GGML_ABORT("not implemented");
2138+
} break;
20672139
default: GGML_ABORT("not implemented");
20682140
}
20692141

0 commit comments

Comments
 (0)