@@ -930,13 +930,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
930
930
return res;
931
931
}
932
932
933
+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
934
+ ggml_metal_library_t lib,
935
+ const struct ggml_tensor * op,
936
+ bool has_mask,
937
+ int32_t ncpsg) {
938
+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
939
+ GGML_UNUSED (op);
940
+
941
+ char base[256 ];
942
+ char name[256 ];
943
+
944
+ snprintf (base, 256 , " kernel_%s" ,
945
+ " flash_attn_ext_pad" );
946
+
947
+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
948
+ base,
949
+ has_mask,
950
+ ncpsg);
951
+
952
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
953
+ if (res) {
954
+ return res;
955
+ }
956
+
957
+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
958
+
959
+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
960
+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
961
+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
962
+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
963
+
964
+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
965
+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
966
+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
967
+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
968
+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
969
+
970
+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
971
+
972
+ ggml_metal_cv_free (cv);
973
+
974
+ return res;
975
+ }
976
+
933
977
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
934
978
ggml_metal_library_t lib,
935
979
const ggml_tensor * op,
936
980
bool has_mask,
937
981
bool has_sinks,
938
982
bool has_bias,
939
983
bool has_scap,
984
+ bool has_kvpad,
940
985
int32_t nsg) {
941
986
assert (op->op == GGML_OP_FLASH_ATTN_EXT);
942
987
@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
955
1000
dk,
956
1001
dv);
957
1002
958
- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
1003
+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=% d_ns10=%d_ns20=%d_nsg=%d" ,
959
1004
base,
960
1005
has_mask,
961
1006
has_sinks,
962
1007
has_bias,
963
1008
has_scap,
1009
+ has_kvpad,
964
1010
ns10,
965
1011
ns20,
966
1012
nsg);
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
976
1022
ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
977
1023
ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
978
1024
ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1025
+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
979
1026
980
1027
ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
981
1028
ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
995
1042
bool has_sinks,
996
1043
bool has_bias,
997
1044
bool has_scap,
1045
+ bool has_kvpad,
998
1046
int32_t nsg,
999
1047
int32_t nwg) {
1000
1048
assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1014
1062
dk,
1015
1063
dv);
1016
1064
1017
- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1065
+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1018
1066
base,
1019
1067
has_mask,
1020
1068
has_sinks,
1021
1069
has_bias,
1022
1070
has_scap,
1071
+ has_kvpad,
1023
1072
ns10,
1024
1073
ns20,
1025
1074
nsg, nwg);
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1035
1084
ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
1036
1085
ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
1037
1086
ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1087
+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
1038
1088
1039
1089
ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
1040
1090
ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments