Skip to content

Commit 2917cad

Browse files
committed
vulkan: bfloat16 fixes (really works without bfloat16 support now)
1 parent 60b5d31 commit 2917cad

File tree

3 files changed

+33
-23
lines changed

3 files changed

+33
-23
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,11 +1924,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
19241924
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19251925
if (device->coopmat_bf16_support) {
19261926
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1927-
} else
1928-
#endif
1929-
{
1930-
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
19311927
}
1928+
#endif
19321929
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
19331930
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
19341931
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1953,11 +1950,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
19531950
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19541951
if (device->coopmat_bf16_support) {
19551952
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1956-
} else
1957-
#endif
1958-
{
1959-
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4);
19601953
}
1954+
#endif
19611955
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19621956
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19631957
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -2014,11 +2008,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20142008
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20152009
if (device->coopmat_bf16_support) {
20162010
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
2017-
} else
2018-
#endif
2019-
{
2020-
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
20212011
}
2012+
#endif
20222013

20232014
if (device->coopmat_acc_f16_support) {
20242015
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2070,11 +2061,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20702061
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20712062
if (device->coopmat_bf16_support) {
20722063
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2073-
} else
2074-
#endif
2075-
{
2076-
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
20772064
}
2065+
#endif
20782066

20792067
if (device->coopmat_acc_f16_support) {
20802068
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2307,8 +2295,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
23072295
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
23082296
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
23092297
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2310-
#undef CREATE_MM
23112298
}
2299+
// reusing CREATE_MM from the fp32 path
2300+
if ((device->coopmat2 || device->coopmat_support)
2301+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2302+
&& !device->coopmat_bf16_support
2303+
#endif
2304+
) {
2305+
// use scalar tile sizes
2306+
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2307+
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
2308+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
2309+
2310+
l_wg_denoms = {128, 128, 1 };
2311+
m_wg_denoms = { 64, 64, 1 };
2312+
s_wg_denoms = { 32, 32, 1 };
2313+
2314+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2315+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2316+
}
2317+
#undef CREATE_MM
23122318

23132319
// mul mat vec
23142320

@@ -2869,7 +2875,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
28692875
#if defined(VK_KHR_shader_bfloat16)
28702876
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
28712877
bfloat16_features.pNext = nullptr;
2872-
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
2878+
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
28732879
if (bfloat16_support) {
28742880
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
28752881
last_struct = (VkBaseOutStructure *)&bfloat16_features;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
1111
#endif
1212

13-
#ifdef DATA_A_BF16
13+
#if defined(DATA_A_BF16) && defined(COOPMAT)
1414
#extension GL_EXT_bfloat16 : enable
1515
#endif
1616

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
340340
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
341341
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
342342

343-
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
344343
// bf16
345344
{
346345
std::string load_vec_a_unaligned = "1";
@@ -350,10 +349,15 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
350349
// scalar path promotes to float
351350
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
352351

353-
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "uint16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
354-
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
355-
}
352+
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
353+
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
354+
if (!(coopmat || coopmat2))
356355
#endif
356+
{
357+
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "uint16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
358+
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
359+
}
360+
}
357361

358362
for (const auto& tname : type_names) {
359363
std::string load_vec_quant = "2";

0 commit comments

Comments
 (0)