@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434}
3535
3636void ggml_metal_pipelines_free (ggml_metal_pipelines_t ppls) {
37+ if (!ppls) {
38+ return ;
39+ }
40+
3741 for (auto it = ppls->data .begin (); it != ppls->data .end (); ++it) {
3842 ggml_metal_pipeline_free (it->second );
3943 }
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467471 // use custom matrix x vector kernel
468472 switch (tsrc0) {
469473 case GGML_TYPE_F32:
474+ case GGML_TYPE_F16:
475+ case GGML_TYPE_BF16:
470476 {
471- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
472-
473- nsg = 1 ;
474- nr0 = 1 ;
475- nr1 = 4 ;
476477 if (ne00 == 4 ) {
478+ nsg = 1 ;
477479 nr0 = 32 ;
480+ nr1 = 4 ;
478481 suffix = " _c4" ;
479- }
480- } break ;
481- case GGML_TYPE_F16:
482- case GGML_TYPE_BF16:
483- {
484- nsg = 1 ;
485- nr0 = 1 ;
486- if (op->src [1 ]->type == GGML_TYPE_F32) {
487- if (ne00 == 4 ) {
488- nr0 = 32 ;
489- nr1 = 4 ;
490- suffix = " _c4" ;
491- } else if (ne11 * ne12 < 4 ) {
492- suffix = " _1row" ;
493- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
494- suffix = " _l4" ;
495- nr1 = ne11;
496- } else {
497- nr1 = 4 ;
498- }
482+ } else if (ne00 % 4 == 0 ) {
483+ nsg = N_SG_F;
484+ nr0 = N_R0_F;
485+ nr1 = 1 ;
486+ smem = 32 *sizeof (float )*N_R0_F;
487+ suffix = " _4" ;
499488 } else {
500- nr1 = 4 ;
489+ nsg = N_SG_F;
490+ nr0 = N_R0_F;
491+ nr1 = 1 ;
492+ smem = 32 *sizeof (float )*N_R0_F;
501493 }
502494 } break ;
503495 case GGML_TYPE_Q4_0:
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
623615 return res;
624616 }
625617
626- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
618+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
619+
620+ ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
621+
622+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
623+
624+ ggml_metal_cv_free (cv);
627625
628626 ggml_metal_pipeline_set_nr0 (res, nr0);
629627 ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689687 const ggml_type tsrc0 = op->src [0 ]->type ;
690688 const ggml_type tsrc1 = op->src [1 ]->type ;
691689
690+ const char * suffix = " " ;
691+
692692 // use custom matrix x vector kernel
693693 switch (tsrc0) {
694694 case GGML_TYPE_F32:
695- {
696- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
697- nsg = 1 ;
698- nr0 = 1 ;
699- } break ;
700695 case GGML_TYPE_F16:
701- {
702- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
703- nsg = 1 ;
704- nr0 = 1 ;
705- } break ;
706696 case GGML_TYPE_BF16:
707697 {
708- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
709- nsg = 1 ;
710- nr0 = 1 ;
698+ if (ne00 % 4 == 0 ) {
699+ nsg = N_SG_F;
700+ nr0 = N_R0_F;
701+ nr1 = 1 ;
702+ smem = 32 *sizeof (float )*N_R0_F;
703+ suffix = " _4" ;
704+ } else {
705+ nsg = N_SG_F;
706+ nr0 = N_R0_F;
707+ nr1 = 1 ;
708+ smem = 32 *sizeof (float )*N_R0_F;
709+ }
711710 } break ;
712711 case GGML_TYPE_Q4_0:
713712 {
@@ -824,15 +823,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824823 }
825824 };
826825
827- snprintf (base, 256 , " kernel_mul_mv_id_%s_%s" , ggml_type_name (tsrc0), ggml_type_name (tsrc1));
826+ snprintf (base, 256 , " kernel_mul_mv_id_%s_%s%s " , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix );
828827 snprintf (name, 256 , " %s" , base);
829828
830829 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
831830 if (res) {
832831 return res;
833832 }
834833
835- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
834+ ggml_metal_cv_t cv = ggml_metal_cv_init ();
835+
836+ ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
837+
838+ res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
839+
840+ ggml_metal_cv_free (cv);
836841
837842 ggml_metal_pipeline_set_nr0 (res, nr0);
838843 ggml_metal_pipeline_set_nr1 (res, nr1);
0 commit comments