@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434}
3535
3636void ggml_metal_pipelines_free (ggml_metal_pipelines_t ppls) {
37+ if (!ppls) {
38+ return ;
39+ }
40+
3741 for (auto it = ppls->data .begin (); it != ppls->data .end (); ++it) {
3842 ggml_metal_pipeline_free (it->second );
3943 }
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467471 // use custom matrix x vector kernel
468472 switch (tsrc0) {
469473 case GGML_TYPE_F32:
474+ case GGML_TYPE_F16:
475+ case GGML_TYPE_BF16:
470476 {
471- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
472-
473- nsg = 1 ;
474- nr0 = 1 ;
475- nr1 = 4 ;
476477 if (ne00 == 4 ) {
478+ nsg = 1 ;
477479 nr0 = 32 ;
480+ nr1 = 4 ;
478481 suffix = " _c4" ;
479- }
480- } break ;
481- case GGML_TYPE_F16:
482- case GGML_TYPE_BF16:
483- {
484- nsg = 1 ;
485- nr0 = 1 ;
486- if (op->src [1 ]->type == GGML_TYPE_F32) {
487- if (ne00 == 4 ) {
488- nr0 = 32 ;
489- nr1 = 4 ;
490- suffix = " _c4" ;
491- } else if (ne11 * ne12 < 4 ) {
492- suffix = " _1row" ;
493- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
494- suffix = " _l4" ;
495- nr1 = ne11;
496- } else {
497- nr1 = 4 ;
498- }
482+ } else if (ne00 % 4 == 0 ) {
483+ nsg = N_SG_F;
484+ nr0 = N_R0_F;
485+ nr1 = 1 ;
486+ smem = 32 *sizeof (float )*N_R0_F;
487+ suffix = " _4" ;
499488 } else {
500- nr1 = 4 ;
489+ nsg = N_SG_F;
490+ nr0 = N_R0_F;
491+ nr1 = 1 ;
492+ smem = 32 *sizeof (float )*N_R0_F;
501493 }
502494 } break ;
503495 case GGML_TYPE_Q4_0:
@@ -689,25 +681,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689681 const ggml_type tsrc0 = op->src [0 ]->type ;
690682 const ggml_type tsrc1 = op->src [1 ]->type ;
691683
684+ const char * suffix = " " ;
685+
692686 // use custom matrix x vector kernel
693687 switch (tsrc0) {
694688 case GGML_TYPE_F32:
695- {
696- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
697- nsg = 1 ;
698- nr0 = 1 ;
699- } break ;
700689 case GGML_TYPE_F16:
701- {
702- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
703- nsg = 1 ;
704- nr0 = 1 ;
705- } break ;
706690 case GGML_TYPE_BF16:
707691 {
708- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
709- nsg = 1 ;
710- nr0 = 1 ;
692+ if (ne00 % 4 == 0 ) {
693+ nsg = N_SG_F;
694+ nr0 = N_R0_F;
695+ nr1 = 1 ;
696+ smem = 32 *sizeof (float )*N_R0_F;
697+ suffix = " _4" ;
698+ } else {
699+ nsg = N_SG_F;
700+ nr0 = N_R0_F;
701+ nr1 = 1 ;
702+ smem = 32 *sizeof (float )*N_R0_F;
703+ }
711704 } break ;
712705 case GGML_TYPE_Q4_0:
713706 {
@@ -824,7 +817,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824817 }
825818 };
826819
827- snprintf (base, 256 , " kernel_mul_mv_id_%s_%s" , ggml_type_name (tsrc0), ggml_type_name (tsrc1));
820+ snprintf (base, 256 , " kernel_mul_mv_id_%s_%s%s " , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix );
828821 snprintf (name, 256 , " %s" , base);
829822
830823 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
0 commit comments