@@ -924,13 +924,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
924
924
return res;
925
925
}
926
926
927
+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
928
+ ggml_metal_library_t lib,
929
+ const struct ggml_tensor * op,
930
+ bool has_mask,
931
+ int32_t ncpsg) {
932
+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
933
+ GGML_UNUSED (op);
934
+
935
+ char base[256 ];
936
+ char name[256 ];
937
+
938
+ snprintf (base, 256 , " kernel_%s" ,
939
+ " flash_attn_ext_pad" );
940
+
941
+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
942
+ base,
943
+ has_mask,
944
+ ncpsg);
945
+
946
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
947
+ if (res) {
948
+ return res;
949
+ }
950
+
951
+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
952
+
953
+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
954
+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
955
+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
956
+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
957
+
958
+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
959
+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960
+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961
+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962
+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
963
+
964
+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
965
+
966
+ ggml_metal_cv_free (cv);
967
+
968
+ return res;
969
+ }
970
+
927
971
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
928
972
ggml_metal_library_t lib,
929
973
const ggml_tensor * op,
930
974
bool has_mask,
931
975
bool has_sinks,
932
976
bool has_bias,
933
977
bool has_scap,
978
+ bool has_kvpad,
934
979
int32_t nsg) {
935
980
assert (op->op == GGML_OP_FLASH_ATTN_EXT);
936
981
@@ -943,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
943
988
const int32_t ns10 = op->src [1 ]->nb [1 ]/op->src [1 ]->nb [0 ];
944
989
const int32_t ns20 = op->src [2 ]->nb [1 ]/op->src [2 ]->nb [0 ];
945
990
991
+ // do bounds checks for the mask?
992
+ const bool bc_mask = op->src [3 ] && (op->src [3 ]->ne [1 ] % 8 != 0 );
993
+
946
994
snprintf (base, 256 , " kernel_%s_%s_dk%d_dv%d" ,
947
995
" flash_attn_ext" ,
948
996
ggml_type_name (op->src [1 ]->type ),
949
997
dk,
950
998
dv);
951
999
952
- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
1000
+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=% d_ns10=%d_ns20=%d_nsg=%d" ,
953
1001
base,
954
1002
has_mask,
955
1003
has_sinks,
956
1004
has_bias,
957
1005
has_scap,
1006
+ has_kvpad,
1007
+ bc_mask,
958
1008
ns10,
959
1009
ns20,
960
1010
nsg);
@@ -970,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
970
1020
ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
971
1021
ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
972
1022
ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1023
+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
1024
+
1025
+ ggml_metal_cv_set_bool (cv, bc_mask, FC_FLASH_ATTN_EXT + 10 );
973
1026
974
1027
ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
975
1028
ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -989,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
989
1042
bool has_sinks,
990
1043
bool has_bias,
991
1044
bool has_scap,
1045
+ bool has_kvpad,
992
1046
int32_t nsg,
993
1047
int32_t nwg) {
994
1048
assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1008,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1008
1062
dk,
1009
1063
dv);
1010
1064
1011
- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1065
+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1012
1066
base,
1013
1067
has_mask,
1014
1068
has_sinks,
1015
1069
has_bias,
1016
1070
has_scap,
1071
+ has_kvpad,
1017
1072
ns10,
1018
1073
ns20,
1019
1074
nsg, nwg);
@@ -1029,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1029
1084
ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
1030
1085
ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
1031
1086
ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1087
+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
1032
1088
1033
1089
ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
1034
1090
ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments