@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
414414 return res;
415415}
416416
417- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
417+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
418418 char base[256 ];
419419 char name[256 ];
420420
421421 snprintf (base, 256 , " kernel_mul_mv_ext_%s_%s_r1_%d" , ggml_type_name (tsrc0), ggml_type_name (tsrc1), r1ptg);
422- snprintf (name, 256 , " %s " , base);
422+ snprintf (name, 256 , " %s_nsg=%d_nxpsg=%d " , base, nsg, nxpsg );
423423
424424 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
425425 if (res) {
426426 return res;
427427 }
428428
429- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
429+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
430+
431+ ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
432+ ggml_metal_cv_set_int16 (cv, nxpsg, FC_MUL_MV + 1 );
433+
434+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
435+
436+ ggml_metal_cv_free (cv);
430437
431438 return res;
432439}
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
608615 };
609616
610617 snprintf (base, 256 , " kernel_mul_mv_%s_%s%s" , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix);
611- snprintf (name, 256 , " %s " , base);
618+ snprintf (name, 256 , " %s_nsg=%d " , base, nsg );
612619
613620 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
614621 if (res) {
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824831 };
825832
826833 snprintf (base, 256 , " kernel_mul_mv_id_%s_%s%s" , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix);
827- snprintf (name, 256 , " %s " , base);
834+ snprintf (name, 256 , " %s_nsg=%d " , base, nsg );
828835
829836 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
830837 if (res) {
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
923930 dk,
924931 dv);
925932
926- snprintf (name, 256 , " kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
927- " flash_attn_ext" ,
928- ggml_type_name (op->src [1 ]->type ),
929- dk,
930- dv,
933+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
934+ base,
931935 has_mask,
932936 has_sinks,
933937 has_bias,
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
985989 dk,
986990 dv);
987991
988- snprintf (name, 256 , " kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
989- " flash_attn_ext_vec" ,
990- ggml_type_name (op->src [1 ]->type ),
991- dk,
992- dv,
992+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
993+ base,
993994 has_mask,
994995 has_sinks,
995996 has_bias,
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
10331034 char name[256 ];
10341035
10351036 snprintf (base, 256 , " kernel_flash_attn_ext_vec_reduce" );
1036- snprintf (name, 256 , " kernel_flash_attn_ext_vec_reduce_dv =%d_nwg=%d" , dv, nwg);
1037+ snprintf (name, 256 , " %s_dv =%d_nwg=%d" , base , dv, nwg);
10371038
10381039 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
10391040 if (res) {
0 commit comments