@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
338338 char base[256 ];
339339 char name[256 ];
340340
341- snprintf (base, 256 , " kernel_ssm_conv_%s_%s" , ggml_type_name (op->src [0 ]->type ), ggml_type_name (op->src [1 ]->type ));
341+ const char * suffix = " " ;
342+
343+ if (op->src [1 ]->ne [0 ] % 4 == 0 ) {
344+ suffix = " _4" ;
345+ }
346+
347+ snprintf (base, 256 , " kernel_ssm_conv_%s_%s%s" , ggml_type_name (op->src [0 ]->type ), ggml_type_name (op->src [1 ]->type ), suffix);
342348 snprintf (name, 256 , " %s" , base);
343349
344350 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
352358}
353359
354360ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const ggml_tensor * op) {
361+ GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
362+
355363 char base[256 ];
356364 char name[256 ];
357365
358- if (op->src [3 ]->ne [0 ] == 1 ) {
359- snprintf (base, 256 , " kernel_ssm_scan_group_%s" , ggml_type_name (op->src [0 ]->type ));
360- } else {
361- snprintf (base, 256 , " kernel_ssm_scan_%s" , ggml_type_name (op->src [0 ]->type ));
362- }
363- snprintf (name, 256 , " %s" , base);
366+ const int nsg = (ne00 + 31 )/32 ;
367+
368+ snprintf (base, 256 , " kernel_ssm_scan_%s" , ggml_type_name (op->src [0 ]->type ));
369+ snprintf (name, 256 , " %s_nsg=%d" , base, nsg);
364370
365371 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
366372 if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
369375
370376 res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
371377
372- ggml_metal_pipeline_set_smem (res, 32 *sizeof (float ));
378+ ggml_metal_pipeline_set_smem (res, 32 *sizeof (float )*nsg );
373379
374380 return res;
375381}
@@ -918,13 +924,104 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
918924 return res;
919925}
920926
927+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
928+ ggml_metal_library_t lib,
929+ const struct ggml_tensor * op,
930+ bool has_mask,
931+ int32_t ncpsg) {
932+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
933+ GGML_UNUSED (op);
934+
935+ char base[256 ];
936+ char name[256 ];
937+
938+ snprintf (base, 256 , " kernel_%s" ,
939+ " flash_attn_ext_pad" );
940+
941+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
942+ base,
943+ has_mask,
944+ ncpsg);
945+
946+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
947+ if (res) {
948+ return res;
949+ }
950+
951+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
952+
953+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
954+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
955+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
956+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
957+
958+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
959+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962+ // ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
963+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25 );
964+
965+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
966+
967+ ggml_metal_cv_free (cv);
968+
969+ return res;
970+ }
971+
972+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk (
973+ ggml_metal_library_t lib,
974+ const struct ggml_tensor * op,
975+ int32_t nqptg,
976+ int32_t ncpsg) {
977+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
978+ GGML_UNUSED (op);
979+
980+ char base[256 ];
981+ char name[256 ];
982+
983+ snprintf (base, 256 , " kernel_%s" ,
984+ " flash_attn_ext_blk" );
985+
986+ snprintf (name, 256 , " %s_nqptg=%d_ncpsg=%d" ,
987+ base,
988+ nqptg,
989+ ncpsg);
990+
991+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
992+ if (res) {
993+ return res;
994+ }
995+
996+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
997+
998+ // ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
999+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1000+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1001+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1002+
1003+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1004+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1005+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1006+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1007+ ggml_metal_cv_set_int32 (cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24 );
1008+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25 );
1009+
1010+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
1011+
1012+ ggml_metal_cv_free (cv);
1013+
1014+ return res;
1015+ }
1016+
9211017ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
9221018 ggml_metal_library_t lib,
9231019 const ggml_tensor * op,
9241020 bool has_mask,
9251021 bool has_sinks,
9261022 bool has_bias,
9271023 bool has_scap,
1024+ bool has_kvpad,
9281025 int32_t nsg) {
9291026 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
9301027
@@ -937,18 +1034,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9371034 const int32_t ns10 = op->src [1 ]->nb [1 ]/op->src [1 ]->nb [0 ];
9381035 const int32_t ns20 = op->src [2 ]->nb [1 ]/op->src [2 ]->nb [0 ];
9391036
1037+ // do bounds checks for the mask?
1038+ const bool bc_mask = op->src [3 ] && (op->src [3 ]->ne [1 ] % 8 != 0 );
1039+
9401040 snprintf (base, 256 , " kernel_%s_%s_dk%d_dv%d" ,
9411041 " flash_attn_ext" ,
9421042 ggml_type_name (op->src [1 ]->type ),
9431043 dk,
9441044 dv);
9451045
946- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
1046+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=% d_ns10=%d_ns20=%d_nsg=%d" ,
9471047 base,
9481048 has_mask,
9491049 has_sinks,
9501050 has_bias,
9511051 has_scap,
1052+ has_kvpad,
1053+ bc_mask,
9521054 ns10,
9531055 ns20,
9541056 nsg);
@@ -964,6 +1066,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9641066 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
9651067 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
9661068 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1069+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
1070+
1071+ ggml_metal_cv_set_bool (cv, bc_mask, FC_FLASH_ATTN_EXT + 10 );
9671072
9681073 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
9691074 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -983,6 +1088,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9831088 bool has_sinks,
9841089 bool has_bias,
9851090 bool has_scap,
1091+ bool has_kvpad,
9861092 int32_t nsg,
9871093 int32_t nwg) {
9881094 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1108,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10021108 dk,
10031109 dv);
10041110
1005- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1111+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
10061112 base,
10071113 has_mask,
10081114 has_sinks,
10091115 has_bias,
10101116 has_scap,
1117+ has_kvpad,
10111118 ns10,
10121119 ns20,
10131120 nsg, nwg);
@@ -1023,6 +1130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10231130 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
10241131 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
10251132 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1133+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
10261134
10271135 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
10281136 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments