@@ -523,13 +523,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@@ -1562,13 +1555,6 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
@@ -1909,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[0]->ne[0] == 32) {
-                // head size == 32 (e.g. bert-bge-small)
-                // TODO: not sure if it is worth adding kernels for this size
+            // for new head sizes, add checks here
+            if (op->src[0]->ne[0] != 40 &&
+                op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 80 &&
+                op->src[0]->ne[0] != 96 &&
+                op->src[0]->ne[0] != 112 &&
+                op->src[0]->ne[0] != 128 &&
+                op->src[0]->ne[0] != 192 &&
+                op->src[0]->ne[0] != 256) {
                 return false;
             }
             if (op->src[0]->ne[0] == 576) {
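For reference, the whitelist added in the hunk above can be read in isolation: GGML_OP_FLASH_ATTN_EXT is reported as supported only when the K head size (ne[0] of src[0]) is one of the sizes that have compiled kernels. The standalone sketch below is illustrative only — the helper name and the sample values are ours; the size list 40/64/80/96/112/128/192/256 and the note about head size 32 come from the diff.

#include <stdbool.h>
#include <stdio.h>

// illustrative helper (not part of ggml) mirroring the whitelist check above
static bool flash_attn_head_size_supported(long ne00) {
    return ne00 == 40  || ne00 == 64  || ne00 == 80  || ne00 == 96 ||
           ne00 == 112 || ne00 == 128 || ne00 == 192 || ne00 == 256;
}

int main(void) {
    // head size 32 (e.g. bert-bge-small) is still rejected, as the removed
    // special case did; head size 40 now passes the check
    const long sizes[] = { 32, 40, 128 };
    for (int i = 0; i < 3; i++) {
        printf("head size %ld -> %s\n", sizes[i],
               flash_attn_head_size_supported(sizes[i]) ? "supported" : "unsupported");
    }
    return 0;
}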
@@ -5138,10 +5130,8 @@ static int ggml_metal_encode_node(

         bool use_vec_kernel = false;

-        // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
-        // for now avoiding mainly to keep the number of templates/kernels a bit lower
-        // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-        if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+        // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+        if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
             switch (src1->type) {
                 case GGML_TYPE_F16:
                 {
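The rewritten dispatch condition above can be summarized as: take the simdgroup-matrix (non-vec) path either when the batch of query rows is large (ne01 >= 20) or when the head size is 40, 80, or 112, which per the new comment have no vec variant. A minimal sketch under those assumptions (the helper name and sample calls are ours):

#include <stdbool.h>
#include <stdio.h>

// illustrative helper (not part of ggml): ne01 = number of query rows,
// ne00 = K head size; returns true when the non-vec kernel should be used
static bool flash_attn_use_non_vec(long ne01, long ne00) {
    return ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112);
}

int main(void) {
    printf("%d\n", (int) flash_attn_use_non_vec( 1,  40)); // 1: no vec kernel for head size 40
    printf("%d\n", (int) flash_attn_use_non_vec( 1, 128)); // 0: single-row decode uses the vec kernel
    printf("%d\n", (int) flash_attn_use_non_vec(32, 128)); // 1: large batch prefers the non-vec kernel
    return 0;
}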
@@ -5329,24 +5319,6 @@ static int ggml_metal_encode_node(
             use_vec_kernel = true;

             switch (ne00) {
-                case 40:
-                    {
-                        switch (src1->type) {
-                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline;  break;
-                            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
-                            case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
-                            case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
-                            case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
-                            case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
-                            case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
-                            default:
-                                {
-                                    GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                    GGML_LOG_ERROR("add template specialization for this type\n");
-                                    GGML_ABORT("add template specialization for this type");
-                                }
-                        }
-                    } break;
                 case 64:
                     {
                         switch (src1->type) {