Skip to content

Commit 06d4f77

Browse files
committed
vulkan: bfloat16 fixes (really works without bfloat16 support now)
1 parent 21e8793 commit 06d4f77

File tree

4 files changed

+45
-39
lines changed

4 files changed

+45
-39
lines changed

ggml/src/ggml-vulkan/CMakeLists.txt

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,19 @@ if (Vulkan_FOUND)
9292
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
9393
endif()
9494

95-
if(NOT DEFINED GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
96-
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
97-
# If it's not, there will be an error to stderr.
98-
# If it's supported, set a define to indicate that we should compile those shaders
99-
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
100-
OUTPUT_VARIABLE glslc_output
101-
ERROR_VARIABLE glslc_error)
102-
103-
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
104-
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
105-
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether bfloat16 is supported by glslc")
106-
else()
107-
message(STATUS "GL_EXT_bfloat16 supported by glslc")
108-
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON CACHE INTERNAL "Whether bfloat16 is supported by glslc")
109-
endif()
95+
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
96+
# If it's not, there will be an error to stderr.
97+
# If it's supported, set a define to indicate that we should compile those shaders
98+
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
99+
OUTPUT_VARIABLE glslc_output
100+
ERROR_VARIABLE glslc_error)
101+
102+
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
103+
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
104+
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether bfloat16 is supported by glslc")
110105
else()
111-
message(STATUS "GL_EXT_bfloat16 support already defined: ${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}")
106+
message(STATUS "GL_EXT_bfloat16 supported by glslc")
107+
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON CACHE INTERNAL "Whether bfloat16 is supported by glslc")
112108
endif()
113109

114110
if(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,11 +1870,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
18701870
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
18711871
if (device->coopmat_bf16_support) {
18721872
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1873-
} else
1874-
#endif
1875-
{
1876-
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
18771873
}
1874+
#endif
18781875
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
18791876
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
18801877
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1899,11 +1896,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
18991896
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19001897
if (device->coopmat_bf16_support) {
19011898
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1902-
} else
1903-
#endif
1904-
{
1905-
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4);
19061899
}
1900+
#endif
19071901
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19081902
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19091903
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1960,11 +1954,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
19601954
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19611955
if (device->coopmat_bf16_support) {
19621956
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
1963-
} else
1964-
#endif
1965-
{
1966-
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
19671957
}
1958+
#endif
19681959

19691960
if (device->coopmat_acc_f16_support) {
19701961
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2016,11 +2007,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20162007
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20172008
if (device->coopmat_bf16_support) {
20182009
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2019-
} else
2020-
#endif
2021-
{
2022-
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
20232010
}
2011+
#endif
20242012

20252013
if (device->coopmat_acc_f16_support) {
20262014
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2253,8 +2241,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
22532241
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
22542242
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
22552243
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2256-
#undef CREATE_MM
22572244
}
2245+
// reusing CREATE_MM from the fp32 path
2246+
if ((device->coopmat2 || device->coopmat_support)
2247+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2248+
&& !device->coopmat_bf16_support
2249+
#endif
2250+
) {
2251+
// use scalar tile sizes
2252+
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2253+
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
2254+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
2255+
2256+
l_wg_denoms = {128, 128, 1 };
2257+
m_wg_denoms = { 64, 64, 1 };
2258+
s_wg_denoms = { 32, 32, 1 };
2259+
2260+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2261+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2262+
}
2263+
#undef CREATE_MM
22582264

22592265
// mul mat vec
22602266

@@ -2814,7 +2820,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
28142820
#if defined(VK_KHR_shader_bfloat16)
28152821
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
28162822
bfloat16_features.pNext = nullptr;
2817-
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
2823+
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
28182824
if (bfloat16_support) {
28192825
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
28202826
last_struct = (VkBaseOutStructure *)&bfloat16_features;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
1111
#endif
1212

13-
#ifdef DATA_A_BF16
13+
#if defined(DATA_A_BF16) && defined(COOPMAT)
1414
#extension GL_EXT_bfloat16 : enable
1515
#endif
1616

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
340340
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
341341
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
342342

343-
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
344343
// bf16
345344
{
346345
std::string load_vec_a_unaligned = "1";
@@ -350,10 +349,15 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
350349
// scalar path promotes to float
351350
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
352351

353-
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "uint16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
354-
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
355-
}
352+
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
353+
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
354+
if (!(coopmat || coopmat2))
356355
#endif
356+
{
357+
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "uint16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
358+
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
359+
}
360+
}
357361

358362
for (const auto& tname : type_names) {
359363
std::string load_vec_quant = "2";

0 commit comments

Comments
 (0)