@@ -523,13 +523,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@@ -1562,13 +1555,6 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,      flash_attn_ext_vec_f16_h40,      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,     flash_attn_ext_vec_bf16_h40,     has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,     flash_attn_ext_vec_q4_0_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,     flash_attn_ext_vec_q4_1_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,     flash_attn_ext_vec_q5_0_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,     flash_attn_ext_vec_q5_1_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,     flash_attn_ext_vec_q8_0_h40,     has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,      flash_attn_ext_vec_f16_h64,      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,     flash_attn_ext_vec_bf16_h64,     has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,     flash_attn_ext_vec_q4_0_h64,     has_simdgroup_reduction);
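
Each `GGML_METAL_ADD_KERNEL` line above pairs an enum slot with a Metal function name and a capability gate (`has_simdgroup_mm`, `has_simdgroup_reduction`, optionally `&& use_bfloat`). As a rough illustration of the pattern — this mock is an assumption for illustration, not the real macro body from ggml-metal.m, and `compile_pipeline` stands in for the actual Metal pipeline setup:

```c
// Hypothetical mock, NOT the real GGML_METAL_ADD_KERNEL definition:
// a kernel is only instantiated when the device supports the features
// it needs; otherwise its pipeline slot is simply left empty.
#define ADD_KERNEL_MOCK(kernels, e, name, supported)         \
    do {                                                     \
        if (supported) {                                     \
            (kernels)[e].pipeline = compile_pipeline(#name); \
        }                                                    \
    } while (0)
```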
@@ -1909,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[0]->ne[0] == 32) {
-                // head size == 32 (e.g. bert-bge-small)
-                // TODO: not sure if it is worth adding kernels for this size
+            // for new head sizes, add checks here
+            if (op->src[0]->ne[0] != 40 &&
+                op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 80 &&
+                op->src[0]->ne[0] != 96 &&
+                op->src[0]->ne[0] != 112 &&
+                op->src[0]->ne[0] != 128 &&
+                op->src[0]->ne[0] != 192 &&
+                op->src[0]->ne[0] != 256) {
                 return false;
             }
             if (op->src[0]->ne[0] == 576) {
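
The old early-out for head size 32 is replaced by an explicit whitelist. Functionally the new check is equivalent to a small predicate along these lines (helper name hypothetical, not part of the patch):

```c
#include <stdbool.h>
#include <stdint.h>

// Hypothetical helper mirroring the whitelist above: a head size passes
// only if flash-attention kernel templates exist for it.
static bool fa_head_size_supported(int64_t ne00) {
    switch (ne00) {
        case 40: case 64: case 80: case 96:
        case 112: case 128: case 192: case 256:
            return true;
        default:
            return false;
    }
}
```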
@@ -5138,10 +5130,8 @@ static int ggml_metal_encode_node(

                 bool use_vec_kernel = false;

-                // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
-                //       for now avoiding mainly to keep the number of templates/kernels a bit lower
-                //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+                // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+                if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
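
The rewritten condition makes the choice explicit: large batches (`ne01 >= 20`) always take the simdgroup-matrix path, as do the head sizes that have no vec template (40, 80, 112). Restated as a standalone sketch (helper name hypothetical):

```c
#include <stdbool.h>
#include <stdint.h>

// Hypothetical restatement of the dispatch rule above: ne00 is the head
// size, ne01 the number of query rows in the batch.
static bool fa_use_vec_kernel(int64_t ne00, int64_t ne01) {
    const bool has_vec_template = (ne00 != 40 && ne00 != 80 && ne00 != 112);
    return ne01 < 20 && has_vec_template; // vec kernels target small batches
}
```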
@@ -5329,24 +5319,6 @@ static int ggml_metal_encode_node(
                     use_vec_kernel = true;

                     switch (ne00) {
-                        case 40:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline;  break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
                         case 64:
                             {
                                 switch (src1->type) {
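
Net effect, derived from the whitelist and dispatch condition above: head sizes 40, 80 and 112 are served exclusively by the simdgroup-matrix kernels, while 64, 96, 128, 192 and 256 keep their vec variants for small batches (`ne01 < 20`). Note that under the old condition (`ne00%128 != 0 && ne00 != 64 && ...`), head size 40 already fell through to the non-vec path unconditionally, so the seven `*_H40` vec pipelines removed here were unreachable; deleting them trims the kernel/template count without changing behavior.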