@@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
407407 GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
408408 GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409409 GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
412+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
413+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
414+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
415+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
416+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
417+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
418+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
419+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
410420 GGML_METAL_KERNEL_TYPE_CONCAT,
411421 GGML_METAL_KERNEL_TYPE_SQR,
412422 GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass
10121022 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true );
10131023 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true );
10141024 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true );
1025+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true );
1026+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true );
1027+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true );
1028+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true );
1029+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true );
1030+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true );
1031+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true );
1032+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true );
1033+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true );
1034+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true );
10151035 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONCAT, concat, true );
10161036 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQR, sqr, true );
10171037 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true );
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
12871307 default :
12881308 return false ;
12891309 }
1310+ case GGML_TYPE_Q4_0:
1311+ case GGML_TYPE_Q4_1:
1312+ case GGML_TYPE_Q5_0:
1313+ case GGML_TYPE_Q5_1:
1314+ case GGML_TYPE_Q8_0:
1315+ switch (op->type ) {
1316+ case GGML_TYPE_F32:
1317+ case GGML_TYPE_F16:
1318+ return true ;
1319+ default :
1320+ return false ;
1321+ }
12901322 default :
12911323 return false ;
12921324 };
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
38993931 case GGML_OP_CPY:
39003932 case GGML_OP_CONT:
39013933 {
3902- GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
3903-
3904- int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
3905-
39063934 id <MTLComputePipelineState > pipeline = nil ;
39073935
39083936 switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
39363964 switch (dstt) {
39373965 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
39383966 case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline ; break ;
3939- default : GGML_ASSERT (false && " not implemented" );
3967+ default : GGML_ABORT (" not implemented" );
3968+ };
3969+ } break ;
3970+ case GGML_TYPE_Q4_0:
3971+ {
3972+ switch (dstt) {
3973+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline ; break ;
3974+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline ; break ;
3975+ default : GGML_ABORT (" not implemented" );
3976+ };
3977+ } break ;
3978+ case GGML_TYPE_Q4_1:
3979+ {
3980+ switch (dstt) {
3981+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline ; break ;
3982+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline ; break ;
3983+ default : GGML_ABORT (" not implemented" );
3984+ };
3985+ } break ;
3986+ case GGML_TYPE_Q5_0:
3987+ {
3988+ switch (dstt) {
3989+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline ; break ;
3990+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline ; break ;
3991+ default : GGML_ABORT (" not implemented" );
3992+ };
3993+ } break ;
3994+ case GGML_TYPE_Q5_1:
3995+ {
3996+ switch (dstt) {
3997+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline ; break ;
3998+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline ; break ;
3999+ default : GGML_ABORT (" not implemented" );
4000+ };
4001+ } break ;
4002+ case GGML_TYPE_Q8_0:
4003+ {
4004+ switch (dstt) {
4005+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline ; break ;
4006+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline ; break ;
4007+ default : GGML_ABORT (" not implemented" );
39404008 };
39414009 } break ;
39424010 default : GGML_ABORT (" not implemented" );
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
39664034 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
39674035 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
39684036
4037+ GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
4038+ int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
4039+
39694040 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4041+
39704042 } break ;
39714043 case GGML_OP_SET:
39724044 {
0 commit comments