@@ -1476,26 +1476,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
14761476        //  spec constants and tile sizes for quant matmul (non-Qi_K)
14771477        l_warptile_mmq = { 256 , 128 , 256 , 64  };
14781478        m_warptile_mmq = { 256 , 128 , 128 , 64  };
1479-         s_warptile_mmq = { 256 , 128 ,  128 ,  64  };
1479+         s_warptile_mmq = { 256 , 32 ,   64 ,  128  };
14801480        l_mmq_wg_denoms = { 128 , 256 , 1  };
14811481        m_mmq_wg_denoms = { 128 , 128 , 1  };
1482-         s_mmq_wg_denoms = { 128 ,  128 ,  1  };
1482+         s_mmq_wg_denoms = { 32 ,   64 ,   1  };
14831483
14841484        //  spec constants and tile sizes for quant matmul (Qi_K)
1485-         l_warptile_mmq_k = { 256 , 128 ,  512 ,  16  };
1486-         m_warptile_mmq_k = { 256 , 128 ,  256 ,  16  };
1487-         s_warptile_mmq_k = { 256 , 32 , 128 ,  64  };
1488-         l_mmq_wg_denoms_k = { 128 ,  512 , 1  };
1489-         m_mmq_wg_denoms_k = { 128 ,  256 , 1  };
1490-         s_mmq_wg_denoms_k = { 32 , 128 , 1  };
1485+         l_warptile_mmq_k = { 256 , 64 ,  128 ,  64  };
1486+         m_warptile_mmq_k = { 256 , 32 ,   64 ,  64  };
1487+         s_warptile_mmq_k = { 256 , 32 ,   32 ,  128  };
1488+         l_mmq_wg_denoms_k = { 64 ,  128 , 1  };
1489+         m_mmq_wg_denoms_k = { 32 ,   64 , 1  };
1490+         s_mmq_wg_denoms_k = { 32 ,   32 , 1  };
14911491
14921492        //  spec constants and tile sizes for quant matmul_id
1493-         l_warptile_mmqid = { 256 , 128 , 128 , 16  };
1493+         l_warptile_mmqid = { 256 , 128 , 64 , 16  };
14941494        m_warptile_mmqid = { 256 , 128 , 64 , 16  };
1495-         s_warptile_mmqid = { 256 , 64 , 64 , 16  };
1496-         l_mmqid_wg_denoms = { 128 , 128 , 1  };
1495+         s_warptile_mmqid = { 256 , 128 , 64 , 16  };
1496+         l_mmqid_wg_denoms = { 128 , 64 , 1  };
14971497        m_mmqid_wg_denoms = { 128 , 64 , 1  };
1498-         s_mmqid_wg_denoms = { 64 , 64 , 1  };
1498+         s_mmqid_wg_denoms = { 128 , 64 , 1  };
14991499
15001500        l_align = 128 ;
15011501        m_align =  64 ;
@@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
38503850    VK_LOG_DEBUG (" ggml_vk_guess_matmul_pipeline("   << m << " , "   << n << " , "   << aligned << " , "   << ggml_type_name (src0_type) << " )"  );
38513851
38523852    if  (ctx->device ->coopmat2 ) {
3853-         if  ((ctx->device ->mul_mat_l [src0_type] && (m % mmp->l ->wg_denoms [0 ]) == 0  && (n % mmp->l ->wg_denoms [1 ]) == 0 ) || (!ctx->device ->mul_mat_m [src0_type] && !ctx->device ->mul_mat_s [src0_type])) {
3853+         //  Use large shader when the N dimension is greater than the medium shader's tile size
3854+         uint32_t  crossover_large = mmp->m ->wg_denoms [1 ];
3855+         if  ((ctx->device ->mul_mat_l [src0_type] && (n > crossover_large)) || (!ctx->device ->mul_mat_m [src0_type] && !ctx->device ->mul_mat_s [src0_type])) {
38543856            return  aligned ? mmp->a_l  : mmp->l ;
38553857        }
3856-         if  ((ctx->device ->mul_mat_m [src0_type] && (m % mmp->m ->wg_denoms [0 ]) == 0  && (n % mmp->m ->wg_denoms [1 ]) == 0 ) || !ctx->device ->mul_mat_s [src0_type]) {
3858+         //  Use medium shader when the N dimension is greater than the small shader's tile size
3859+         uint32_t  crossover_medium = mmp->s ->wg_denoms [1 ];
3860+         if  ((ctx->device ->mul_mat_m [src0_type] && (n > crossover_medium)) || !ctx->device ->mul_mat_s [src0_type]) {
38573861            return  aligned ? mmp->a_m  : mmp->m ;
38583862        }
38593863        return  aligned ? mmp->a_s  : mmp->s ;
@@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
38983902}
38993903
39003904static  vk_pipeline ggml_vk_guess_matmul_id_pipeline (ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int  m, int  n, bool  aligned, ggml_type src0_type) {
3901-     VK_LOG_DEBUG (" ggml_vk_guess_matmul_pipeline ("   << m << " , "   << n << " , "   << aligned << " , "   << ggml_type_name (src0_type) << " )"  );
3905+     VK_LOG_DEBUG (" ggml_vk_guess_matmul_id_pipeline ("   << m << " , "   << n << " , "   << aligned << " , "   << ggml_type_name (src0_type) << " )"  );
39023906
39033907    if  (ctx->device ->coopmat2 ) {
3904-         if  ((ctx->device ->mul_mat_id_l [src0_type] && (m % mmp->l ->wg_denoms [0 ]) == 0  && (n % mmp->l ->wg_denoms [1 ]) == 0 ) || (!ctx->device ->mul_mat_id_m [src0_type] && !ctx->device ->mul_mat_id_s [src0_type])) {
3908+         //  Use large shader when the N dimension is greater than the medium shader's tile size
3909+         uint32_t  crossover_large = mmp->m ->wg_denoms [1 ];
3910+         if  ((ctx->device ->mul_mat_id_l [src0_type] && (n > crossover_large)) || (!ctx->device ->mul_mat_id_m [src0_type] && !ctx->device ->mul_mat_id_s [src0_type])) {
39053911            return  aligned ? mmp->a_l  : mmp->l ;
39063912        }
3907-         if  ((ctx->device ->mul_mat_id_m [src0_type] && (m % mmp->m ->wg_denoms [0 ]) == 0  && (n % mmp->m ->wg_denoms [1 ]) == 0 ) || !ctx->device ->mul_mat_id_s [src0_type]) {
3913+         //  Use medium shader when the N dimension is greater than the small shader's tile size
3914+         uint32_t  crossover_medium = mmp->s ->wg_denoms [1 ];
3915+         if  ((ctx->device ->mul_mat_id_m [src0_type] && (n > crossover_medium)) || !ctx->device ->mul_mat_id_s [src0_type]) {
39083916            return  aligned ? mmp->a_m  : mmp->m ;
39093917        }
39103918        return  aligned ? mmp->a_s  : mmp->s ;
0 commit comments