@@ -1340,7 +1340,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
13401340 // Needs to be kept up to date on shader changes
13411341 const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1 ;
13421342 const uint32_t type_size = device->fp16 ? sizeof (ggml_fp16_t ) : sizeof (float );
1343- const uint32_t warps = warptile[0 ] / device-> subgroup_size ;
1343+ const uint32_t warps = warptile[0 ] / warptile[ 10 ] ;
13441344
13451345 const uint32_t load_bufs = (warptile[1 ] + warptile[2 ]) * (warptile[3 ] + bank_conflict_offset) * type_size;
13461346 const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof (uint32_t ) : 0 ;
@@ -1354,8 +1354,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
13541354
13551355 std::cerr << " ggml_vulkan: Compiling shaders" ;
13561356
1357- // some shaders require the subgroup size to be 16 or larger
1357+ // some shaders have a minimum subgroup size
13581358 const uint32_t subgroup_size_16 = std::max (device->subgroup_size , 16u );
1359+ const uint32_t subgroup_size_32 = std::max (device->subgroup_size , 32u );
13591360
13601361 // mulmat
13611362 std::vector<uint32_t > l_warptile, m_warptile, s_warptile,
@@ -1418,11 +1419,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
14181419
14191420 l_warptile = { 128 , 128 , 128 , 16 , device->subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device->subgroup_size };
14201421 m_warptile = { 128 , 64 , 64 , 16 , device->subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device->subgroup_size };
1421- s_warptile = { subgroup_size_16 , 32 , 32 , 16 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device->subgroup_size };
1422+ s_warptile = { subgroup_size_32 , 32 , 32 , 16 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device->subgroup_size };
14221423
14231424 l_warptile_mmq = { 128 , 128 , 128 , 32 , device->subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device->subgroup_size };
14241425 m_warptile_mmq = { 128 , 64 , 64 , 32 , device->subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device->subgroup_size };
1425- s_warptile_mmq = { subgroup_size_16 , 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device->subgroup_size };
1426+ s_warptile_mmq = { subgroup_size_32 , 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device->subgroup_size };
14261427
14271428 l_mmq_wg_denoms = l_wg_denoms = {128 , 128 , 1 };
14281429 m_mmq_wg_denoms = m_wg_denoms = { 64 , 64 , 1 };
@@ -1792,7 +1793,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
17921793 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_Q4_K], " mul_mat_vec_q4_k_f32_f32" , mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
17931794 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_Q5_K], " mul_mat_vec_q5_k_f32_f32" , mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
17941795 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_Q6_K], " mul_mat_vec_q6_k_f32_f32" , mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1795- ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f32_f32" , mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {device-> subgroup_size , 2 }, 1 , true );
1796+ ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f32_f32" , mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {subgroup_size_16 , 2 }, 1 , true );
17961797
17971798 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_F32 ], " mul_mat_vec_f32_f16_f32" , mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
17981799 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_F16 ], " mul_mat_vec_f16_f16_f32" , mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
@@ -1806,7 +1807,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
18061807 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_Q4_K], " mul_mat_vec_q4_k_f16_f32" , mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
18071808 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_Q5_K], " mul_mat_vec_q5_k_f16_f32" , mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
18081809 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_Q6_K], " mul_mat_vec_q6_k_f16_f32" , mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1809- ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f16_f32" , mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {device-> subgroup_size }, 1 , true );
1810+ ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f16_f32" , mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {subgroup_size_16, 2 }, 1 , true );
18101811
18111812 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_F32 ], " mul_mat_vec_id_f32_f32" , mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
18121813 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_F16 ], " mul_mat_vec_id_f16_f32" , mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
@@ -1820,7 +1821,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
18201821 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_Q4_K], " mul_mat_vec_id_q4_k_f32" , mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
18211822 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_Q5_K], " mul_mat_vec_id_q5_k_f32" , mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
18221823 ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_Q6_K], " mul_mat_vec_id_q6_k_f32" , mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1823- ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_id_iq4_nl_f32" , mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 , 1 , 1 }, {device-> subgroup_size , 2 }, 1 , true );
1824+ ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_id_iq4_nl_f32" , mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 , 1 , 1 }, {subgroup_size_16 , 2 }, 1 , true );
18241825
18251826 // dequant shaders
18261827 ggml_vk_create_pipeline (device, device->pipeline_dequant [GGML_TYPE_F32 ], " f32_to_f16" , dequant_f32_len, dequant_f32_data, " main" , 2 , 5 * sizeof (uint32_t ), {256 * 16 , 1 , 1 }, {}, 1 );
0 commit comments