@@ -1870,11 +1870,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
18701870#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
18711871 if (device->coopmat_bf16_support) {
18721872 CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1873- } else
1874- #endif
1875- {
1876- CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
18771873 }
1874+ #endif
18781875 CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
18791876 CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
18801877 CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1899,11 +1896,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
18991896#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19001897 if (device->coopmat_bf16_support) {
19011898 CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1902- } else
1903- #endif
1904- {
1905- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4);
19061899 }
1900+ #endif
19071901 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19081902 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19091903 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1960,11 +1954,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
19601954#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19611955 if (device->coopmat_bf16_support) {
19621956 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
1963- } else
1964- #endif
1965- {
1966- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
19671957 }
1958+ #endif
19681959
19691960 if (device->coopmat_acc_f16_support) {
19701961 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2016,11 +2007,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20162007#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20172008 if (device->coopmat_bf16_support) {
20182009 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2019- } else
2020- #endif
2021- {
2022- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
20232010 }
2011+ #endif
20242012
20252013 if (device->coopmat_acc_f16_support) {
20262014 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2253,8 +2241,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
22532241 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
22542242 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
22552243 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2256- #undef CREATE_MM
22572244 }
2245+ // reusing CREATE_MM from the fp32 path
2246+ if ((device->coopmat2 || device->coopmat_support)
2247+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2248+ && !device->coopmat_bf16_support
2249+ #endif
2250+ ) {
2251+ // use scalar tile sizes
2252+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2253+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
2254+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
2255+
2256+ l_wg_denoms = {128, 128, 1 };
2257+ m_wg_denoms = { 64, 64, 1 };
2258+ s_wg_denoms = { 32, 32, 1 };
2259+
2260+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2261+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2262+ }
2263+ #undef CREATE_MM
22582264
22592265 // mul mat vec
22602266
@@ -2814,7 +2820,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
28142820#if defined(VK_KHR_shader_bfloat16)
28152821 VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
28162822 bfloat16_features.pNext = nullptr;
2817- bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV ;
2823+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ;
28182824 if (bfloat16_support) {
28192825 last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
28202826 last_struct = (VkBaseOutStructure *)&bfloat16_features;
0 commit comments