Skip to content

Commit fdc2bb1

Browse files
committed
metal : fattn quantization (wip)
1 parent 1926d6e commit fdc2bb1

File tree

2 files changed

+342
-16
lines changed

2 files changed

+342
-16
lines changed

ggml/src/ggml-metal.m

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
258258
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
259259
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
260260
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
261+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
261262
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
262263
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
263264
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -706,6 +707,7 @@ @implementation GGMLMetalClass
706707
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
707708
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
708709
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
710+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_reduction);
709711
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
710712
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
711713
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@@ -862,12 +864,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
862864
case GGML_OP_LEAKY_RELU:
863865
return true;
864866
case GGML_OP_FLASH_ATTN_EXT:
865-
if (op->src[1]->type != GGML_TYPE_F16) {
866-
return false;
867-
}
868-
if (op->src[2]->type != GGML_TYPE_F16) {
869-
return false;
870-
}
871867
if (op->src[0]->ne[0] == 256) {
872868
return false;
873869
}
@@ -2861,7 +2857,11 @@ static void ggml_metal_encode_node(
28612857

28622858
bool use_vec_kernel = false;
28632859

2864-
if (ne01 >= 4 || (ne00%128 != 0)) {
2860+
if (src1->type == GGML_TYPE_Q8_0 && src2->type == GGML_TYPE_Q8_0) {
2861+
use_vec_kernel = true;
2862+
2863+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline;
2864+
} else if (ne01 >= 4 || (ne00%128 != 0)) {
28652865
switch (ne00) {
28662866
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
28672867
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;

0 commit comments

Comments
 (0)