@@ -1474,28 +1474,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
14741474 s_wg_denoms = { 64 , 64 , 1 };
14751475
14761476 // spec constants and tile sizes for quant matmul (non-Qi_K)
1477- l_warptile_mmq = { 256 , 128 , 256 , 64 };
1478- m_warptile_mmq = { 256 , 128 , 128 , 64 };
1479- s_warptile_mmq = { 256 , 128 , 128 , 64 };
1480- l_mmq_wg_denoms = { 128 , 256 , 1 };
1481- m_mmq_wg_denoms = { 128 , 128 , 1 };
1482- s_mmq_wg_denoms = { 128 , 128 , 1 };
1477+ l_warptile_mmq = { 256 , 64 , 128 , 64 };
1478+ m_warptile_mmq = { 256 , 32 , 64 , 64 };
1479+ s_warptile_mmq = { 256 , 32 , 32 , 128 };
1480+ l_mmq_wg_denoms = { 64 , 128 , 1 };
1481+ m_mmq_wg_denoms = { 32 , 64 , 1 };
1482+ s_mmq_wg_denoms = { 32 , 32 , 1 };
14831483
14841484 // spec constants and tile sizes for quant matmul (Qi_K)
1485- l_warptile_mmq_k = { 256 , 128 , 512 , 16 };
1486- m_warptile_mmq_k = { 256 , 128 , 256 , 16 };
1487- s_warptile_mmq_k = { 256 , 32 , 128 , 64 };
1488- l_mmq_wg_denoms_k = { 128 , 512 , 1 };
1489- m_mmq_wg_denoms_k = { 128 , 256 , 1 };
1490- s_mmq_wg_denoms_k = { 32 , 128 , 1 };
1485+ l_warptile_mmq_k = { 256 , 64 , 128 , 64 };
1486+ m_warptile_mmq_k = { 256 , 32 , 64 , 64 };
1487+ s_warptile_mmq_k = { 256 , 32 , 32 , 128 };
1488+ l_mmq_wg_denoms_k = { 64 , 128 , 1 };
1489+ m_mmq_wg_denoms_k = { 32 , 64 , 1 };
1490+ s_mmq_wg_denoms_k = { 32 , 32 , 1 };
14911491
14921492 // spec constants and tile sizes for quant matmul_id
1493- l_warptile_mmqid = { 256 , 128 , 128 , 16 };
1493+ l_warptile_mmqid = { 256 , 128 , 64 , 16 };
14941494 m_warptile_mmqid = { 256 , 128 , 64 , 16 };
1495- s_warptile_mmqid = { 256 , 64 , 64 , 16 };
1496- l_mmqid_wg_denoms = { 128 , 128 , 1 };
1495+ s_warptile_mmqid = { 256 , 128 , 64 , 16 };
1496+ l_mmqid_wg_denoms = { 128 , 64 , 1 };
14971497 m_mmqid_wg_denoms = { 128 , 64 , 1 };
1498- s_mmqid_wg_denoms = { 64 , 64 , 1 };
1498+ s_mmqid_wg_denoms = { 128 , 64 , 1 };
14991499
15001500 l_align = 128 ;
15011501 m_align = 64 ;
@@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
38503850 VK_LOG_DEBUG (" ggml_vk_guess_matmul_pipeline(" << m << " , " << n << " , " << aligned << " , " << ggml_type_name (src0_type) << " )" );
38513851
38523852 if (ctx->device ->coopmat2 ) {
3853- if ((ctx->device ->mul_mat_l [src0_type] && (m % mmp->l ->wg_denoms [0 ]) == 0 && (n % mmp->l ->wg_denoms [1 ]) == 0 ) || (!ctx->device ->mul_mat_m [src0_type] && !ctx->device ->mul_mat_s [src0_type])) {
3853+ // Use large shader when the N dimension is greater than the medium shader's tile size
3854+ uint32_t crossover_large = mmp->m ->wg_denoms [1 ];
3855+ if ((ctx->device ->mul_mat_l [src0_type] && (n > crossover_large)) || (!ctx->device ->mul_mat_m [src0_type] && !ctx->device ->mul_mat_s [src0_type])) {
38543856 return aligned ? mmp->a_l : mmp->l ;
38553857 }
3856- if ((ctx->device ->mul_mat_m [src0_type] && (m % mmp->m ->wg_denoms [0 ]) == 0 && (n % mmp->m ->wg_denoms [1 ]) == 0 ) || !ctx->device ->mul_mat_s [src0_type]) {
3858+ // Use medium shader when the N dimension is greater than the small shader's tile size
3859+ uint32_t crossover_medium = mmp->s ->wg_denoms [1 ];
3860+ if ((ctx->device ->mul_mat_m [src0_type] && (n > crossover_medium)) || !ctx->device ->mul_mat_s [src0_type]) {
38573861 return aligned ? mmp->a_m : mmp->m ;
38583862 }
38593863 return aligned ? mmp->a_s : mmp->s ;
@@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
38983902}
38993903
39003904static vk_pipeline ggml_vk_guess_matmul_id_pipeline (ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3901- VK_LOG_DEBUG (" ggml_vk_guess_matmul_pipeline (" << m << " , " << n << " , " << aligned << " , " << ggml_type_name (src0_type) << " )" );
3905+ VK_LOG_DEBUG (" ggml_vk_guess_matmul_id_pipeline (" << m << " , " << n << " , " << aligned << " , " << ggml_type_name (src0_type) << " )" );
39023906
39033907 if (ctx->device ->coopmat2 ) {
3904- if ((ctx->device ->mul_mat_id_l [src0_type] && (m % mmp->l ->wg_denoms [0 ]) == 0 && (n % mmp->l ->wg_denoms [1 ]) == 0 ) || (!ctx->device ->mul_mat_id_m [src0_type] && !ctx->device ->mul_mat_id_s [src0_type])) {
3908+ // Use large shader when the N dimension is greater than the medium shader's tile size
3909+ uint32_t crossover_large = mmp->m ->wg_denoms [1 ];
3910+ if ((ctx->device ->mul_mat_id_l [src0_type] && (n > crossover_large)) || (!ctx->device ->mul_mat_id_m [src0_type] && !ctx->device ->mul_mat_id_s [src0_type])) {
39053911 return aligned ? mmp->a_l : mmp->l ;
39063912 }
3907- if ((ctx->device ->mul_mat_id_m [src0_type] && (m % mmp->m ->wg_denoms [0 ]) == 0 && (n % mmp->m ->wg_denoms [1 ]) == 0 ) || !ctx->device ->mul_mat_id_s [src0_type]) {
3913+ // Use medium shader when the N dimension is greater than the small shader's tile size
3914+ uint32_t crossover_medium = mmp->s ->wg_denoms [1 ];
3915+ if ((ctx->device ->mul_mat_id_m [src0_type] && (n > crossover_medium)) || !ctx->device ->mul_mat_id_s [src0_type]) {
39083916 return aligned ? mmp->a_m : mmp->m ;
39093917 }
39103918 return aligned ? mmp->a_s : mmp->s ;
0 commit comments