@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434}
3535
3636void ggml_metal_pipelines_free (ggml_metal_pipelines_t ppls) {
37+ if (!ppls) {
38+ return ;
39+ }
40+
3741 for (auto it = ppls->data .begin (); it != ppls->data .end (); ++it) {
3842 ggml_metal_pipeline_free (it->second );
3943 }
@@ -410,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
410414 return res;
411415}
412416
413- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
417+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
414418 char base[256 ];
415419 char name[256 ];
416420
417421 snprintf (base, 256 , " kernel_mul_mv_ext_%s_%s_r1_%d" , ggml_type_name (tsrc0), ggml_type_name (tsrc1), r1ptg);
418- snprintf (name, 256 , " %s " , base);
422+ snprintf (name, 256 , " %s_nsg=%d_nxpsg=%d " , base, nsg, nxpsg );
419423
420424 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
421425 if (res) {
422426 return res;
423427 }
424428
425- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
429+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
430+
431+ ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
432+ ggml_metal_cv_set_int16 (cv, nxpsg, FC_MUL_MV + 1 );
433+
434+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
435+
436+ ggml_metal_cv_free (cv);
426437
427438 return res;
428439}
@@ -467,37 +478,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467478 // use custom matrix x vector kernel
468479 switch (tsrc0) {
469480 case GGML_TYPE_F32:
481+ case GGML_TYPE_F16:
482+ case GGML_TYPE_BF16:
470483 {
471- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
472-
473- nsg = 1 ;
474- nr0 = 1 ;
475- nr1 = 4 ;
476484 if (ne00 == 4 ) {
485+ nsg = 1 ;
477486 nr0 = 32 ;
487+ nr1 = 4 ;
478488 suffix = " _c4" ;
479- }
480- } break ;
481- case GGML_TYPE_F16:
482- case GGML_TYPE_BF16:
483- {
484- nsg = 1 ;
485- nr0 = 1 ;
486- if (op->src [1 ]->type == GGML_TYPE_F32) {
487- if (ne00 == 4 ) {
488- nr0 = 32 ;
489- nr1 = 4 ;
490- suffix = " _c4" ;
491- } else if (ne11 * ne12 < 4 ) {
492- suffix = " _1row" ;
493- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
494- suffix = " _l4" ;
495- nr1 = ne11;
496- } else {
497- nr1 = 4 ;
498- }
489+ } else if (ne00 % 4 == 0 ) {
490+ nsg = N_SG_F;
491+ nr0 = N_R0_F;
492+ nr1 = 1 ;
493+ smem = 32 *sizeof (float )*N_R0_F;
494+ suffix = " _4" ;
499495 } else {
500- nr1 = 4 ;
496+ nsg = N_SG_F;
497+ nr0 = N_R0_F;
498+ nr1 = 1 ;
499+ smem = 32 *sizeof (float )*N_R0_F;
501500 }
502501 } break ;
503502 case GGML_TYPE_Q4_0:
@@ -616,14 +615,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
616615 };
617616
618617 snprintf (base, 256 , " kernel_mul_mv_%s_%s%s" , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix);
619- snprintf (name, 256 , " %s " , base);
618+ snprintf (name, 256 , " %s_nsg=%d " , base, nsg );
620619
621620 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
622621 if (res) {
623622 return res;
624623 }
625624
626- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
625+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
626+
627+ ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
628+
629+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
630+
631+ ggml_metal_cv_free (cv);
627632
628633 ggml_metal_pipeline_set_nr0 (res, nr0);
629634 ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +694,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689694 const ggml_type tsrc0 = op->src [0 ]->type ;
690695 const ggml_type tsrc1 = op->src [1 ]->type ;
691696
697+ const char * suffix = " " ;
698+
692699 // use custom matrix x vector kernel
693700 switch (tsrc0) {
694701 case GGML_TYPE_F32:
695- {
696- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
697- nsg = 1 ;
698- nr0 = 1 ;
699- } break ;
700702 case GGML_TYPE_F16:
701- {
702- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
703- nsg = 1 ;
704- nr0 = 1 ;
705- } break ;
706703 case GGML_TYPE_BF16:
707704 {
708- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
709- nsg = 1 ;
710- nr0 = 1 ;
705+ if (ne00 % 4 == 0 ) {
706+ nsg = N_SG_F;
707+ nr0 = N_R0_F;
708+ nr1 = 1 ;
709+ smem = 32 *sizeof (float )*N_R0_F;
710+ suffix = " _4" ;
711+ } else {
712+ nsg = N_SG_F;
713+ nr0 = N_R0_F;
714+ nr1 = 1 ;
715+ smem = 32 *sizeof (float )*N_R0_F;
716+ }
711717 } break ;
712718 case GGML_TYPE_Q4_0:
713719 {
@@ -824,15 +830,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824830 }
825831 };
826832
827- snprintf (base, 256 , " kernel_mul_mv_id_%s_%s" , ggml_type_name (tsrc0), ggml_type_name (tsrc1));
828- snprintf (name, 256 , " %s " , base);
833+ snprintf (base, 256 , " kernel_mul_mv_id_%s_%s%s " , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix );
834+ snprintf (name, 256 , " %s_nsg=%d " , base, nsg );
829835
830836 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
831837 if (res) {
832838 return res;
833839 }
834840
835- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
841+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
842+
843+ ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
844+
845+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
846+
847+ ggml_metal_cv_free (cv);
836848
837849 ggml_metal_pipeline_set_nr0 (res, nr0);
838850 ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -918,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
918930 dk,
919931 dv);
920932
921- snprintf (name, 256 , " kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
922- " flash_attn_ext" ,
923- ggml_type_name (op->src [1 ]->type ),
924- dk,
925- dv,
933+ snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d" ,
934+ base,
926935 has_mask,
927936 has_sinks,
928937 has_bias,
@@ -980,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
980989 dk,
981990 dv);
982991
983- snprintf (name, 256 , " kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
984- " flash_attn_ext_vec" ,
985- ggml_type_name (op->src [1 ]->type ),
986- dk,
987- dv,
992+ snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d" ,
993+ base,
988994 has_mask,
989995 has_sinks,
990996 has_bias,
@@ -1028,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
10281034 char name[256 ];
10291035
10301036 snprintf (base, 256 , " kernel_flash_attn_ext_vec_reduce" );
1031- snprintf (name, 256 , " kernel_flash_attn_ext_vec_reduce_dv =%d_nwg=%d" , dv, nwg);
1037+ snprintf (name, 256 , " %s_dv =%d_nwg=%d" , base , dv, nwg);
10321038
10331039 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
10341040 if (res) {
0 commit comments