@@ -402,6 +402,13 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
402402    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
403403    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
404404    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
405+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
406+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
407+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
408+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
409+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
410+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
411+     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
405412    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
406413    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
407414    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
@@ -1059,6 +1066,13 @@ @implementation GGMLMetalClass
10591066        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,        flash_attn_ext_q8_0_h192,        has_simdgroup_mm);
10601067        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
10611068        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
1069+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,      flash_attn_ext_vec_f16_h96,      has_simdgroup_reduction);
1070+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,     flash_attn_ext_vec_bf16_h96,     has_simdgroup_reduction && use_bfloat);
1071+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,     flash_attn_ext_vec_q4_0_h96,     has_simdgroup_reduction);
1072+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,     flash_attn_ext_vec_q4_1_h96,     has_simdgroup_reduction);
1073+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,     flash_attn_ext_vec_q5_0_h96,     has_simdgroup_reduction);
1074+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,     flash_attn_ext_vec_q5_1_h96,     has_simdgroup_reduction);
1075+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,     flash_attn_ext_vec_q8_0_h96,     has_simdgroup_reduction);
10621076        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,     flash_attn_ext_vec_f16_h128,     has_simdgroup_reduction);
10631077        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,    flash_attn_ext_vec_bf16_h128,    has_simdgroup_reduction && use_bfloat);
10641078        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,    flash_attn_ext_vec_q4_0_h128,    has_simdgroup_reduction);
@@ -3843,7 +3857,7 @@ static void ggml_metal_encode_node(
38433857                //  TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
38443858                //        for now avoiding mainly to keep the number of templates/kernels a bit lower
38453859                //        these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
3846-                 if  (ne01 >= 4  || (ne00%128  != 0  && ne00 != 192 )) {
3860+                 if  (ne01 >= 4  || (ne00%128  != 0  && ne00 != 96  && ne00 !=  192 )) {
38473861                    switch  (src1->type ) {
38483862                        case  GGML_TYPE_F16:
38493863                            {
@@ -4010,6 +4024,24 @@ static void ggml_metal_encode_node(
40104024                    use_vec_kernel = true ;
40114025
40124026                    switch  (ne00) {
4027+                         case  96 :
4028+                             {
4029+                                 switch  (src1->type ) {
4030+                                     case  GGML_TYPE_F16:  pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline ; break ;
4031+                                     case  GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline ; break ;
4032+                                     case  GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline ; break ;
4033+                                     case  GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline ; break ;
4034+                                     case  GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline ; break ;
4035+                                     case  GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline ; break ;
4036+                                     case  GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline ; break ;
4037+                                     default :
4038+                                         {
4039+                                             GGML_LOG_ERROR (" unsupported type: %d \n "  , src1->type );
4040+                                             GGML_LOG_ERROR (" add template specialization for this type\n "  );
4041+                                             GGML_ABORT (" add template specialization for this type"  );
4042+                                         }
4043+                                 }
4044+                             } break ;
40134045                        case  128 :
40144046                            {
40154047                                switch  (src1->type ) {
0 commit comments