@@ -415,6 +415,13 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
415415    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
416416    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
417417    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
418+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
419+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
420+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
421+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
422+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
423+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
424+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
418425    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
419426    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
420427    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -1362,6 +1369,13 @@ @implementation GGMLMetalClass
13621369        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
13631370        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
13641371        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
1372+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,      flash_attn_ext_vec_f16_h64,      has_simdgroup_reduction);
1373+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,     flash_attn_ext_vec_bf16_h64,     has_simdgroup_reduction && use_bfloat);
1374+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,     flash_attn_ext_vec_q4_0_h64,     has_simdgroup_reduction);
1375+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,     flash_attn_ext_vec_q4_1_h64,     has_simdgroup_reduction);
1376+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,     flash_attn_ext_vec_q5_0_h64,     has_simdgroup_reduction);
1377+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,     flash_attn_ext_vec_q5_1_h64,     has_simdgroup_reduction);
1378+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,     flash_attn_ext_vec_q8_0_h64,     has_simdgroup_reduction);
13651379        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,      flash_attn_ext_vec_f16_h96,      has_simdgroup_reduction);
13661380        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,     flash_attn_ext_vec_bf16_h96,     has_simdgroup_reduction && use_bfloat);
13671381        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,     flash_attn_ext_vec_q4_0_h96,     has_simdgroup_reduction);
@@ -4358,7 +4372,7 @@ static bool ggml_metal_encode_node(
43584372                //  TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
43594373                //        for now avoiding mainly to keep the number of templates/kernels a bit lower
43604374                //        these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
4361-                 if  (ne01 >= 20  || (ne00%128  != 0  && ne00 != 96  && ne00 != 192  && ne00 != 576 )) {
4375+                 if  (ne01 >= 20  || (ne00%128  != 0  && ne00 != 64  && ne00 !=  96  && ne00 != 192  && ne00 != 576 )) {
43624376                    switch  (src1->type ) {
43634377                        case  GGML_TYPE_F16:
43644378                            {
@@ -4539,6 +4553,24 @@ static bool ggml_metal_encode_node(
45394553                    use_vec_kernel = true ;
45404554
45414555                    switch  (ne00) {
4556+                         case  64 :
4557+                             {
4558+                                 switch  (src1->type ) {
4559+                                     case  GGML_TYPE_F16:  pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline ; break ;
4560+                                     case  GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline ; break ;
4561+                                     case  GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline ; break ;
4562+                                     case  GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline ; break ;
4563+                                     case  GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline ; break ;
4564+                                     case  GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline ; break ;
4565+                                     case  GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline ; break ;
4566+                                     default :
4567+                                         {
4568+                                             GGML_LOG_ERROR (" unsupported type: %d \n "  , src1->type );
4569+                                             GGML_LOG_ERROR (" add template specialization for this type\n "  );
4570+                                             GGML_ABORT (" add template specialization for this type"  );
4571+                                         }
4572+                                 }
4573+                             } break ;
45424574                        case  96 :
45434575                            {
45444576                                switch  (src1->type ) {
0 commit comments