@@ -930,13 +930,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
930930 return res;
931931}
932932
933+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
934+ ggml_metal_library_t lib,
935+ const struct ggml_tensor * op,
936+ bool has_mask,
937+ int32_t ncpsg) {
938+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
939+ GGML_UNUSED (op);
940+
941+ char base[256 ];
942+ char name[256 ];
943+
944+ snprintf (base, 256 , " kernel_%s" ,
945+ " flash_attn_ext_pad" );
946+
947+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
948+ base,
949+ has_mask,
950+ ncpsg);
951+
952+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
953+ if (res) {
954+ return res;
955+ }
956+
957+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
958+
959+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
960+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
961+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
962+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
963+
964+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
965+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
966+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
967+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
968+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
969+
970+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
971+
972+ ggml_metal_cv_free (cv);
973+
974+ return res;
975+ }
976+
933977ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
934978 ggml_metal_library_t lib,
935979 const ggml_tensor * op,
936980 bool has_mask,
937981 bool has_sinks,
938982 bool has_bias,
939983 bool has_scap,
984+ bool has_kvpad,
940985 int32_t nsg) {
941986 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
942987
@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9551000 dk,
9561001 dv);
9571002
958- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
1003+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=% d_ns10=%d_ns20=%d_nsg=%d" ,
9591004 base,
9601005 has_mask,
9611006 has_sinks,
9621007 has_bias,
9631008 has_scap,
1009+ has_kvpad,
9641010 ns10,
9651011 ns20,
9661012 nsg);
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9761022 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
9771023 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
9781024 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1025+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
9791026
9801027 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
9811028 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9951042 bool has_sinks,
9961043 bool has_bias,
9971044 bool has_scap,
1045+ bool has_kvpad,
9981046 int32_t nsg,
9991047 int32_t nwg) {
10001048 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10141062 dk,
10151063 dv);
10161064
1017- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1065+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
10181066 base,
10191067 has_mask,
10201068 has_sinks,
10211069 has_bias,
10221070 has_scap,
1071+ has_kvpad,
10231072 ns10,
10241073 ns20,
10251074 nsg, nwg);
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10351084 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
10361085 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
10371086 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1087+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
10381088
10391089 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
10401090 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments