@@ -1430,6 +1430,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
14301430 VK_LOG_DEBUG (" ggml_vk_load_shaders(" << device->name << " )" );
14311431
14321432 // some shaders have a minimum subgroup size
1433+ const uint32_t subgroup_size_8 = std::max (device->subgroup_size , 8u );
14331434 const uint32_t subgroup_size_16 = std::max (device->subgroup_size , 16u );
14341435 const uint32_t subgroup_size_32 = std::max (device->subgroup_size , 32u );
14351436
@@ -1492,13 +1493,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
14921493 const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1 ;
14931494 const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1 ;
14941495
1495- l_warptile = { 128 , 128 , 128 , 16 , device-> subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device-> subgroup_size };
1496- m_warptile = { 128 , 64 , 64 , 16 , device-> subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device-> subgroup_size };
1497- s_warptile = { subgroup_size_16, 32 , 32 , 16 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device-> subgroup_size };
1496+ l_warptile = { 128 , 128 , 128 , 16 , subgroup_size_8 * 2 , 64 , 2 , tm_l, tn_l, tk_l, subgroup_size_8 };
1497+ m_warptile = { 128 , 64 , 64 , 16 , subgroup_size_8 , 32 , 2 , tm_m, tn_m, tk_m, subgroup_size_8 };
1498+ s_warptile = { subgroup_size_16, 32 , 32 , 16 , 32 , 32 , 2 , tm_s, tn_s, tk_s, subgroup_size_8 };
14981499
1499- l_warptile_mmq = { 128 , 128 , 128 , 32 , device-> subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device-> subgroup_size };
1500- m_warptile_mmq = { 128 , 64 , 64 , 32 , device-> subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device-> subgroup_size };
1501- s_warptile_mmq = { subgroup_size_32, 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device-> subgroup_size };
1500+ l_warptile_mmq = { 128 , 128 , 128 , 32 , subgroup_size_8 * 2 , 64 , 2 , tm_l, tn_l, tk_l, subgroup_size_8 };
1501+ m_warptile_mmq = { 128 , 64 , 64 , 32 , subgroup_size_8 , 32 , 2 , tm_m, tn_m, tk_m, subgroup_size_8 };
1502+ s_warptile_mmq = { subgroup_size_32, 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, subgroup_size_8 };
15021503
15031504 l_mmq_wg_denoms = l_wg_denoms = {128 , 128 , 1 };
15041505 m_mmq_wg_denoms = m_wg_denoms = { 64 , 64 , 1 };
0 commit comments