@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
338338 char base[256 ];
339339 char name[256 ];
340340
341- snprintf (base, 256 , " kernel_ssm_conv_%s_%s" , ggml_type_name (op->src [0 ]->type ), ggml_type_name (op->src [1 ]->type ));
341+ const char * suffix = " " ;
342+
343+ if (op->src [1 ]->ne [0 ] % 4 == 0 ) {
344+ suffix = " _4" ;
345+ }
346+
347+ snprintf (base, 256 , " kernel_ssm_conv_%s_%s%s" , ggml_type_name (op->src [0 ]->type ), ggml_type_name (op->src [1 ]->type ), suffix);
342348 snprintf (name, 256 , " %s" , base);
343349
344350 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
352358}
353359
354360ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const ggml_tensor * op) {
361+ GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
362+
355363 char base[256 ];
356364 char name[256 ];
357365
358- if (op->src [3 ]->ne [0 ] == 1 ) {
359- snprintf (base, 256 , " kernel_ssm_scan_group_%s" , ggml_type_name (op->src [0 ]->type ));
360- } else {
361- snprintf (base, 256 , " kernel_ssm_scan_%s" , ggml_type_name (op->src [0 ]->type ));
362- }
363- snprintf (name, 256 , " %s" , base);
366+ const int nsg = (ne00 + 31 )/32 ;
367+
368+ snprintf (base, 256 , " kernel_ssm_scan_%s" , ggml_type_name (op->src [0 ]->type ));
369+ snprintf (name, 256 , " %s_nsg=%d" , base, nsg);
364370
365371 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
366372 if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
369375
370376 res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
371377
372- ggml_metal_pipeline_set_smem (res, 32 *sizeof (float ));
378+ ggml_metal_pipeline_set_smem (res, 32 *sizeof (float )*nsg );
373379
374380 return res;
375381}
@@ -918,13 +924,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
918924 return res;
919925}
920926
927+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad (
928+ ggml_metal_library_t lib,
929+ const struct ggml_tensor * op,
930+ bool has_mask,
931+ int32_t ncpsg) {
932+ assert (op->op == GGML_OP_FLASH_ATTN_EXT);
933+ GGML_UNUSED (op);
934+
935+ char base[256 ];
936+ char name[256 ];
937+
938+ snprintf (base, 256 , " kernel_%s" ,
939+ " flash_attn_ext_pad" );
940+
941+ snprintf (name, 256 , " %s_mask=%d_ncpsg=%d" ,
942+ base,
943+ has_mask,
944+ ncpsg);
945+
946+ ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
947+ if (res) {
948+ return res;
949+ }
950+
951+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
952+
953+ ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0 );
954+ // ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
955+ // ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
956+ // ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
957+
958+ // ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
959+ // ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
960+ // ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
961+ // ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
962+ ggml_metal_cv_set_int32 (cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24 );
963+
964+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
965+
966+ ggml_metal_cv_free (cv);
967+
968+ return res;
969+ }
970+
921971ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext (
922972 ggml_metal_library_t lib,
923973 const ggml_tensor * op,
924974 bool has_mask,
925975 bool has_sinks,
926976 bool has_bias,
927977 bool has_scap,
978+ bool has_kvpad,
928979 int32_t nsg) {
929980 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
930981
@@ -937,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
937988 const int32_t ns10 = op->src [1 ]->nb [1 ]/op->src [1 ]->nb [0 ];
938989 const int32_t ns20 = op->src [2 ]->nb [1 ]/op->src [2 ]->nb [0 ];
939990
991+ // do bounds checks for the mask?
992+ const bool bc_mask = op->src [3 ] && (op->src [3 ]->ne [1 ] % 8 != 0 );
993+
940994 snprintf (base, 256 , " kernel_%s_%s_dk%d_dv%d" ,
941995 " flash_attn_ext" ,
942996 ggml_type_name (op->src [1 ]->type ),
943997 dk,
944998 dv);
945999
946- snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
1000+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=% d_ns10=%d_ns20=%d_nsg=%d" ,
9471001 base,
9481002 has_mask,
9491003 has_sinks,
9501004 has_bias,
9511005 has_scap,
1006+ has_kvpad,
1007+ bc_mask,
9521008 ns10,
9531009 ns20,
9541010 nsg);
@@ -964,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
9641020 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT + 1 );
9651021 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT + 2 );
9661022 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT + 3 );
1023+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT + 4 );
1024+
1025+ ggml_metal_cv_set_bool (cv, bc_mask, FC_FLASH_ATTN_EXT + 10 );
9671026
9681027 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT + 20 );
9691028 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT + 21 );
@@ -983,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
9831042 bool has_sinks,
9841043 bool has_bias,
9851044 bool has_scap,
1045+ bool has_kvpad,
9861046 int32_t nsg,
9871047 int32_t nwg) {
9881048 assert (op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10021062 dk,
10031063 dv);
10041064
1005- snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
1065+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad =%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
10061066 base,
10071067 has_mask,
10081068 has_sinks,
10091069 has_bias,
10101070 has_scap,
1071+ has_kvpad,
10111072 ns10,
10121073 ns20,
10131074 nsg, nwg);
@@ -1023,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
10231084 ggml_metal_cv_set_bool (cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1 );
10241085 ggml_metal_cv_set_bool (cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2 );
10251086 ggml_metal_cv_set_bool (cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3 );
1087+ ggml_metal_cv_set_bool (cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4 );
10261088
10271089 ggml_metal_cv_set_int32 (cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20 );
10281090 ggml_metal_cv_set_int32 (cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21 );
0 commit comments