@@ -918,13 +918,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
918
918
return res;
919
919
}
920
920
921
+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
922
+ ggml_metal_library_t lib,
923
+ const struct ggml_tensor * op,
924
+ bool has_mask,
925
+ int32_t ncpsg) {
926
+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
927
+ GGML_UNUSED (op);
928
+
929
+ char base[256 ];
930
+ char name[256 ];
931
+
932
+ snprintf (base, 256 , " kernel_%s" ,
933
+ " flash_attn_ext_pad" );
934
+
935
+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
936
+ base,
937
+ has_mask,
938
+ ncpsg);
939
+
940
+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
941
+ if (res) {
942
+ return res;
943
+ }
944
+
945
+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
946
+
947
+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
948
+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
949
+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
950
+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
951
+
952
+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
953
+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
954
+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
955
+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
956
+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
957
+
958
+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
959
+
960
+ ggml_metal_cv_free (cv);
961
+
962
+ return res;
963
+ }
964
+
921
965
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
922
966
ggml_metal_library_t lib,
923
967
const ggml_tensor * op,
924
968
bool has_mask,
925
969
bool has_sinks,
926
970
bool has_bias,
927
971
bool has_scap,
972
+ bool has_kvpad,
928
973
int32_t nsg) {
929
974
assert (op->op == GGML_OP_FLASH_ATTN_EXT);
930
975
@@ -943,12 +988,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
943
988
dk,
944
989
dv);
945
990
946
- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
991
+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=% d_ns10=%d_ns20=%d_nsg=%d" ,
947
992
base,
948
993
has_mask,
949
994
has_sinks,
950
995
has_bias,
951
996
has_scap,
997
+ has_kvpad,
952
998
ns10,
953
999
ns20,
954
1000
nsg);
@@ -964,6 +1010,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
964
1010
ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
965
1011
ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
966
1012
ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1013
+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
967
1014
968
1015
ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
969
1016
ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -983,6 +1030,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
983
1030
bool has_sinks,
984
1031
bool has_bias,
985
1032
bool has_scap,
1033
+ bool has_kvpad,
986
1034
int32_t nsg,
987
1035
int32_t nwg) {
988
1036
assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1050,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1002
1050
dk,
1003
1051
dv);
1004
1052
1005
- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1053
+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1006
1054
base,
1007
1055
has_mask,
1008
1056
has_sinks,
1009
1057
has_bias,
1010
1058
has_scap,
1059
+ has_kvpad,
1011
1060
ns10,
1012
1061
ns20,
1013
1062
nsg, nwg);
@@ -1023,6 +1072,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1023
1072
ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
1024
1073
ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
1025
1074
ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1075
+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
1026
1076
1027
1077
ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
1028
1078
ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments