@@ -924,13 +924,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
924924 return res;
925925}
926926
927+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
928+ ggml_metal_library_t lib,
929+ const struct ggml_tensor * op,
930+ bool has_mask,
931+ int32_t ncpsg) {
932+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
933+ GGML_UNUSED (op);
934+
935+ char base[256 ];
936+ char name[256 ];
937+
938+ snprintf (base, 256 , " kernel_%s" ,
939+ " flash_attn_ext_pad" );
940+
941+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
942+ base,
943+ has_mask,
944+ ncpsg);
945+
946+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
947+ if (res) {
948+ return res;
949+ }
950+
951+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
952+
953+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
954+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
955+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
956+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
957+
958+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
959+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
963+
964+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
965+
966+ ggml_metal_cv_free (cv);
967+
968+ return res;
969+ }
970+
927971ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
928972 ggml_metal_library_t lib,
929973 const ggml_tensor * op,
930974 bool has_mask,
931975 bool has_sinks,
932976 bool has_bias,
933977 bool has_scap,
978+ bool has_kvpad,
934979 int32_t nsg) {
935980 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
936981
@@ -943,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
943988 const int32_t ns10 = op->src [1 ]->nb [1 ]/op->src [1 ]->nb [0 ];
944989 const int32_t ns20 = op->src [2 ]->nb [1 ]/op->src [2 ]->nb [0 ];
945990
991+ // do bounds checks for the mask?
992+ const bool bc_mask = op->src [3 ] && (op->src [3 ]->ne [1 ] % 8 != 0 );
993+
946994 snprintf (base, 256 , " kernel_%s_%s_dk%d_dv%d" ,
947995 " flash_attn_ext" ,
948996 ggml_type_name (op->src [1 ]->type ),
949997 dk,
950998 dv);
951999
952- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
1000+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=% d_ns10=%d_ns20=%d_nsg=%d" ,
9531001 base,
9541002 has_mask,
9551003 has_sinks,
9561004 has_bias,
9571005 has_scap,
1006+ has_kvpad,
1007+ bc_mask,
9581008 ns10,
9591009 ns20,
9601010 nsg);
@@ -970,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9701020 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
9711021 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
9721022 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1023+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
1024+
1025+ ggml_metal_cv_set_bool (cv, bc_mask, FC_FLASH_ATTN_EXT + 10 );
9731026
9741027 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
9751028 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -989,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9891042 bool has_sinks,
9901043 bool has_bias,
9911044 bool has_scap,
1045+ bool has_kvpad,
9921046 int32_t nsg,
9931047 int32_t nwg) {
9941048 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1008,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10081062 dk,
10091063 dv);
10101064
1011- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1065+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
10121066 base,
10131067 has_mask,
10141068 has_sinks,
10151069 has_bias,
10161070 has_scap,
1071+ has_kvpad,
10171072 ns10,
10181073 ns20,
10191074 nsg, nwg);
@@ -1029,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10291084 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
10301085 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
10311086 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1087+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
10321088
10331089 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
10341090 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments