@@ -415,6 +415,13 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -1362,6 +1369,13 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
@@ -4358,7 +4372,7 @@ static bool ggml_metal_encode_node(
             // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
             //       for now avoiding mainly to keep the number of templates/kernels a bit lower
             //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-            if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+            if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                 switch (src1->type) {
                     case GGML_TYPE_F16:
                         {
@@ -4539,6 +4553,24 @@ static bool ggml_metal_encode_node(
                 use_vec_kernel = true;
 
                 switch (ne00) {
+                    case 64:
+                        {
+                            switch (src1->type) {
+                                case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline;  break;
+                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
+                                case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
+                                case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
+                                case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
+                                case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
+                                case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
+                                default:
+                                    {
+                                        GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                        GGML_LOG_ERROR("add template specialization for this type\n");
+                                        GGML_ABORT("add template specialization for this type");
+                                    }
+                            }
+                        } break;
                     case 96:
                         {
                             switch (src1->type) {
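
For context, the condition modified in the third hunk is what routes flash attention to the "vec" kernels registered above: with head size 64 now accepted, the vec path is taken whenever the number of query rows (ne01) is small and the head dimension (ne00) is one of the supported sizes. Below is a minimal standalone restatement of that rule in plain C; the helper name is hypothetical and the real code inlines this check inside ggml_metal_encode_node.

    #include <stdbool.h>
    #include <stdint.h>

    // Sketch of the dispatch rule from the patched condition (assumption: ne00 is the
    // attention head dimension, ne01 the number of query rows in the current call).
    // The simdgroup-mm kernels handle large batches; otherwise the per-row vec kernels
    // are used, which after this change also cover head size 64.
    static bool use_flash_attn_vec_kernel(int64_t ne00, int64_t ne01) {
        const bool supported_head =
            ne00 % 128 == 0 || ne00 == 64 || ne00 == 96 || ne00 == 192 || ne00 == 576;

        return ne01 < 20 && supported_head;
    }

This is simply the logical negation of the `if (ne01 >= 20 || ...)` test in the diff, written out for readability; it is not additional logic introduced by the patch.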