Skip to content

Commit 874ed80

Browse files
support metal for i2_s
1 parent 814d0ee commit 874ed80

File tree

3 files changed

+214
-54
lines changed

3 files changed

+214
-54
lines changed

ggml/src/ggml-common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,12 @@ typedef struct {
267267
} block_q2_K;
268268
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
269269

270+
// GPU-side block layout for the I2_S quantization type.
// The block holds only the raw quant bytes: QK_K/4 bytes per block and no
// per-block scale or min fields (contrast with block_q2_K above, whose size
// also includes two ggml_half values and QK_K/16 scale bytes).
// NOTE(review): presumably four 2-bit values are packed per byte, giving QK_K
// weights per block — confirm against the i2_s Metal kernels.
typedef struct {
271+
uint8_t qs[QK_K/4]; // quants: QK_K/4 bytes of packed quantized values
272+
} block_i2_s;
273+
// Compile-time guard: the struct must be exactly QK_K/4 bytes, with no padding.
static_assert(sizeof(block_i2_s) == QK_K/4, "wrong gpu i2_s block size/padding");
274+
275+
270276
// 3-bit quantization
271277
// weight is represented as x = a * q
272278
// 16 blocks of 16 elements each

ggml/src/ggml-metal.m

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
151151
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
152152
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
153153
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
154+
GGML_METAL_KERNEL_TYPE_MUL_MV_I2_S_F32,
154155
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
155156
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
156157
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
@@ -196,6 +197,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
196197
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
197198
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
198199
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
200+
GGML_METAL_KERNEL_TYPE_MUL_MM_I2_S_F32,
199201
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
200202
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
201203
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
@@ -595,6 +597,7 @@ @implementation GGMLMetalClass
595597
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
596598
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
597599
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
600+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_I2_S_F32, mul_mv_i2_s_f32, support_simdgroup_reduction);
598601
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
599602
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
600603
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
@@ -640,6 +643,7 @@ @implementation GGMLMetalClass
640643
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
641644
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
642645
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
646+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_I2_S_F32, mul_mm_i2_s_f32, support_simdgroup_mm);
643647
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
644648
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
645649
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
@@ -1791,6 +1795,7 @@ static void ggml_metal_encode_node(
17911795
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
17921796
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
17931797
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1798+
case GGML_TYPE_I2_S : pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_I2_S_F32 ].pipeline; break;
17941799
default: GGML_ABORT("MUL MAT-MAT not implemented");
17951800
}
17961801

@@ -1853,6 +1858,12 @@ static void ggml_metal_encode_node(
18531858
nth1 = 8;
18541859
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
18551860
} break;
1861+
case GGML_TYPE_I2_S:
1862+
{
1863+
nth0 = 2;
1864+
nth1 = 32;
1865+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_I2_S_F32].pipeline;
1866+
} break;
18561867
case GGML_TYPE_Q4_1:
18571868
{
18581869
nth0 = 8;
@@ -1989,7 +2000,7 @@ static void ggml_metal_encode_node(
19892000
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
19902001
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
19912002

1992-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2003+
if (src0t == GGML_TYPE_I2_S || src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
19932004
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
19942005
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
19952006
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];

0 commit comments

Comments
 (0)