@@ -258,6 +258,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
258258 // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
259259 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
260260 // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
261+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
261262 GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
262263 GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
263264 GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -706,6 +707,7 @@ @implementation GGMLMetalClass
706707 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
707708 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
708709 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
710+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_reduction);
709711 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
710712 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
711713 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
@@ -862,12 +864,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
862864 case GGML_OP_LEAKY_RELU:
863865 return true ;
864866 case GGML_OP_FLASH_ATTN_EXT:
865- if (op->src [1 ]->type != GGML_TYPE_F16) {
866- return false ;
867- }
868- if (op->src [2 ]->type != GGML_TYPE_F16) {
869- return false ;
870- }
871867 if (op->src [0 ]->ne [0 ] == 256 ) {
872868 return false ;
873869 }
@@ -2861,7 +2857,11 @@ static void ggml_metal_encode_node(
28612857
28622858 bool use_vec_kernel = false ;
28632859
2864- if (ne01 >= 4 || (ne00%128 != 0 )) {
2860+ if (src1->type == GGML_TYPE_Q8_0 && src2->type == GGML_TYPE_Q8_0) {
2861+ use_vec_kernel = true ;
2862+
2863+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline ;
2864+ } else if (ne01 >= 4 || (ne00%128 != 0 )) {
28652865 switch (ne00) {
28662866 case 64 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline ; break ;
28672867 case 80 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline ; break ;
0 commit comments