@@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
407407    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
408408    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409409    GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410+     GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411+     GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
412+     GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
413+     GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
414+     GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
415+     GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
416+     GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
417+     GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
418+     GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
419+     GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
410420    GGML_METAL_KERNEL_TYPE_CONCAT,
411421    GGML_METAL_KERNEL_TYPE_SQR,
412422    GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass
10121022        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true );
10131023        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true );
10141024        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true );
1025+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                  cpy_q4_0_f32,                   true );
1026+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                  cpy_q4_0_f16,                   true );
1027+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                  cpy_q4_1_f32,                   true );
1028+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                  cpy_q4_1_f16,                   true );
1029+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                  cpy_q5_0_f32,                   true );
1030+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                  cpy_q5_0_f16,                   true );
1031+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                  cpy_q5_1_f32,                   true );
1032+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                  cpy_q5_1_f16,                   true );
1033+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                  cpy_q8_0_f32,                   true );
1034+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                  cpy_q8_0_f16,                   true );
10151035        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true );
10161036        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true );
10171037        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true );
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
12871307                            default :
12881308                                return  false ;
12891309                        }
1310+                     case  GGML_TYPE_Q4_0:
1311+                     case  GGML_TYPE_Q4_1:
1312+                     case  GGML_TYPE_Q5_0:
1313+                     case  GGML_TYPE_Q5_1:
1314+                     case  GGML_TYPE_Q8_0:
1315+                         switch  (op->type ) {
1316+                             case  GGML_TYPE_F32:
1317+                             case  GGML_TYPE_F16:
1318+                                 return  true ;
1319+                             default :
1320+                                 return  false ;
1321+                         }
12901322                    default :
12911323                        return  false ;
12921324                };
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
38993931        case  GGML_OP_CPY:
39003932        case  GGML_OP_CONT:
39013933            {
3902-                 GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
3903- 
3904-                 int  nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
3905- 
39063934                id <MTLComputePipelineState > pipeline = nil ;
39073935
39083936                switch  (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
39363964                            switch  (dstt) {
39373965                                case  GGML_TYPE_F32:  pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
39383966                                case  GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline ; break ;
3939-                                 default : GGML_ASSERT (false  && " not implemented"  );
3967+                                 default : GGML_ABORT (" not implemented"  );
3968+                             };
3969+                         } break ;
3970+                     case  GGML_TYPE_Q4_0:
3971+                         {
3972+                             switch  (dstt) {
3973+                                 case  GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline ; break ;
3974+                                 case  GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline ; break ;
3975+                                 default : GGML_ABORT (" not implemented"  );
3976+                             };
3977+                         } break ;
3978+                     case  GGML_TYPE_Q4_1:
3979+                         {
3980+                             switch  (dstt) {
3981+                                 case  GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline ; break ;
3982+                                 case  GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline ; break ;
3983+                                 default : GGML_ABORT (" not implemented"  );
3984+                             };
3985+                         } break ;
3986+                     case  GGML_TYPE_Q5_0:
3987+                         {
3988+                             switch  (dstt) {
3989+                                 case  GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline ; break ;
3990+                                 case  GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline ; break ;
3991+                                 default : GGML_ABORT (" not implemented"  );
3992+                             };
3993+                         } break ;
3994+                     case  GGML_TYPE_Q5_1:
3995+                         {
3996+                             switch  (dstt) {
3997+                                 case  GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline ; break ;
3998+                                 case  GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline ; break ;
3999+                                 default : GGML_ABORT (" not implemented"  );
4000+                             };
4001+                         } break ;
4002+                     case  GGML_TYPE_Q8_0:
4003+                         {
4004+                             switch  (dstt) {
4005+                                 case  GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline ; break ;
4006+                                 case  GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline ; break ;
4007+                                 default : GGML_ABORT (" not implemented"  );
39404008                            };
39414009                        } break ;
39424010                    default : GGML_ABORT (" not implemented"  );
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
39664034                [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
39674035                [encoder setBuffer: id_dst  offset: offs_dst  atIndex: 2 ];
39684036
4037+                 GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
4038+                 int  nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
4039+ 
39694040                [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4041+ 
39704042            } break ;
39714043        case  GGML_OP_SET:
39724044            {
0 commit comments