@@ -1924,11 +1924,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
19241924#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19251925 if (device->coopmat_bf16_support) {
19261926 CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1927- } else
1928- #endif
1929- {
1930- CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
19311927 }
1928+ #endif
19321929 CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
19331930 CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
19341931 CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1953,11 +1950,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
19531950#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19541951 if (device->coopmat_bf16_support) {
19551952 CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1956- } else
1957- #endif
1958- {
1959- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4);
19601953 }
1954+ #endif
19611955 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19621956 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19631957 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -2014,11 +2008,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20142008#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20152009 if (device->coopmat_bf16_support) {
20162010 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
2017- } else
2018- #endif
2019- {
2020- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
20212011 }
2012+ #endif
20222013
20232014 if (device->coopmat_acc_f16_support) {
20242015 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2070,11 +2061,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20702061#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20712062 if (device->coopmat_bf16_support) {
20722063 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2073- } else
2074- #endif
2075- {
2076- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
20772064 }
2065+ #endif
20782066
20792067 if (device->coopmat_acc_f16_support) {
20802068 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2307,8 +2295,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
23072295 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
23082296 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
23092297 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2310- #undef CREATE_MM
23112298 }
2299+ // reusing CREATE_MM from the fp32 path
2300+ if ((device->coopmat2 || device->coopmat_support)
2301+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2302+ && !device->coopmat_bf16_support
2303+ #endif
2304+ ) {
2305+ // use scalar tile sizes
2306+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2307+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
2308+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
2309+
2310+ l_wg_denoms = {128, 128, 1 };
2311+ m_wg_denoms = { 64, 64, 1 };
2312+ s_wg_denoms = { 32, 32, 1 };
2313+
2314+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2315+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2316+ }
2317+ #undef CREATE_MM
23122318
23132319 // mul mat vec
23142320
@@ -2869,7 +2875,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
28692875#if defined(VK_KHR_shader_bfloat16)
28702876 VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
28712877 bfloat16_features.pNext = nullptr;
2872- bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV ;
2878+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ;
28732879 if (bfloat16_support) {
28742880 last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
28752881 last_struct = (VkBaseOutStructure *)&bfloat16_features;
0 commit comments